diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h index 7d96544cc1..cf1aa52629 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h @@ -57,7 +57,7 @@ __device__ void load_bmatrix_layout(T* arr, uint32_t* R, uint32_t dimY) { static_assert(std::is_same_v<T, __half>, "Only supported for __half types"); const int lane_id = threadIdx.x % 64; int b_idx = ((lane_id % 4) + 4 * (lane_id / 16)) * dimY + ((lane_id % 16) / 4) * 4; - mma_impl::hip::load_fragment_4x4_half_registers<__half>(R, &arr[b_idx]); + mma_impl::hip::load_quad_transposed_fragment<__half>(R, &arr[b_idx]); } /// @brief Prints the four `half` values held in a thread's registers. diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 10a491bbc4..1316e454ce 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -21,7 +21,26 @@ namespace hip { #define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x) -__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { +/// @brief Transposes a 4x4 matrix of `half` values held across a quad of 4 threads. +/// @details This function operates on a group of 4 consecutive threads (a quad). It assumes +/// each thread holds 4 `half` values, which together form a 4x4 matrix where each +/// thread holds one row. The function permutes these values using a series of +/// `__shfl_xor` operations so that each thread ends up holding one column of the +/// original 4x4 matrix. 
+/// +/// Visual Representation: +/// If `[a,b,c,d]` are the 4 `half` values in Thread 0's registers: +/// +/// Before: After: +/// Thread 0: [a, b, c, d] Thread 0: [a, e, i, m] +/// Thread 1: [e, f, g, h] ---> Thread 1: [b, f, j, n] +/// Thread 2: [i, j, k, l] Thread 2: [c, g, k, o] +/// Thread 3: [m, n, o, p] Thread 3: [d, h, l, p] +/// +/// @note This function can be combined with `transpose_inter_quad_fragments` to perform a +/// full 16x16 in-register matrix transpose. This function handles the transposition +/// *within* each 4x4 data block. +__device__ __forceinline__ void transpose_intra_quad_fragments(uint32_t* R) { // Calculate lane within 4-thread group uint32_t lane_id = threadIdx.x % 64; uint32_t lane_in_group = lane_id % 4; @@ -55,6 +74,43 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); } +/// @brief Permutes matrix fragments between thread quads in a wavefront to perform a block-wise +/// transpose. +/// @details This function treats the 64-thread wavefront as a 4x4 grid of thread quads. +/// Each quad (4 consecutive threads) is considered to hold a 4x4 data fragment. +/// The function transposes this 4x4 grid of fragments by swapping the register +/// contents of threads in off-diagonal quads. 
+/// +/// Visual Representation: +/// If B(r,c) is the 4x4 data fragment held by the quad at block-row 'r' and block-col 'c': +/// +/// Before: After: +/// +--------+--------+--------+--------+ +--------+--------+--------+--------+ +/// | B(0,0) | B(0,1) | B(0,2) | B(0,3) | | B(0,0) | B(1,0) | B(2,0) | B(3,0) | +/// +--------+--------+--------+--------+ +--------+--------+--------+--------+ +/// | B(1,0) | B(1,1) | B(1,2) | B(1,3) | ---> | B(0,1) | B(1,1) | B(2,1) | B(3,1) | +/// +--------+--------+--------+--------+ +--------+--------+--------+--------+ +/// | B(2,0) | B(2,1) | B(2,2) | B(2,3) | | B(0,2) | B(1,2) | B(2,2) | B(3,2) | +/// +--------+--------+--------+--------+ +--------+--------+--------+--------+ +/// | B(3,0) | B(3,1) | B(3,2) | B(3,3) | | B(0,3) | B(1,3) | B(2,3) | B(3,3) | +/// +--------+--------+--------+--------+ +--------+--------+--------+--------+ +/// +/// @note This function can be combined with `transpose_intra_quad_fragments` (which transposes +/// the data *within* each fragment) to perform a full 16x16 in-register matrix transpose. 
+__device__ __forceinline__ void transpose_inter_quad_fragments(uint32_t* R) { + uint32_t lane_id = threadIdx.x % 64; + + uint32_t block_row = (lane_id % 16) / 4; + uint32_t block_col = (lane_id / 16); + uint32_t thread_in_block = lane_id % 4; + uint32_t partner_lane_id = (block_row * 16) + (block_col * 4) + thread_in_block; + uint32_t xor_mask = lane_id ^ partner_lane_id; + + // Exchange both registers with the partner thread + R[0] = __shfl_xor(R[0], xor_mask, 64); + R[1] = __shfl_xor(R[1], xor_mask, 64); +} + // Single unified load function for all fragment types /// @param R [in] pointer to the register file to load the fragment into /// @param smem_ptr [in] pointer to the shared memory to load the fragment from @@ -116,10 +172,10 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u /// T2 : c g k o /// T3 : d h l p template <typename T> -__device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) { +__device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const T* smem_ptr) { static_assert(std::is_same_v<T, __half>, "Only half type is supported"); load_fragment(R, smem_ptr); - transpose_4x4_half_registers(R); + transpose_intra_quad_fragments(R); } // TODO: Verify correct matrix multiplication order for rowsum on CDNA3 @@ -135,7 +191,7 @@ __device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, co template <typename DType> __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); - transpose_4x4_half_registers(reinterpret_cast<uint32_t*>(s_frag)); + transpose_intra_quad_fragments(reinterpret_cast<uint32_t*>(s_frag)); f16x4 a = reinterpret_cast<f16x4*>(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; f32x4 c = {d[0], d[1], d[2], d[3]}; diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 6df297a96e..b015b116a7 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp 
+++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -33,13 +33,12 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { mma_detail::load_fragment(R, smem_ptr); } -#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) +#if defined(PLATFORM_HIP_DEVICE) template <typename T> -__device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t* R, - const T* smem_ptr) { +__device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const T* smem_ptr) { static_assert(std::is_same<T, __half>::value, - "Only __half is supported for the 4x4 register transpose"); - mma_detail::load_fragment_4x4_half_registers<__half>(R, smem_ptr); + "Only __half is supported for load_quad_transposed_fragment"); + mma_detail::load_quad_transposed_fragment(R, smem_ptr); } #endif diff --git a/libflashinfer/tests/hip/test_layout_transform.cpp b/libflashinfer/tests/hip/test_layout_transform.cpp new file mode 100644 index 0000000000..0c9c4da93c --- /dev/null +++ b/libflashinfer/tests/hip/test_layout_transform.cpp @@ -0,0 +1,212 @@ +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include <gtest/gtest.h> +#include <hip/hip_runtime.h> + +#include "gpu_iface/backend/hip/mma_debug_utils_hip.h" +#include "gpu_iface/backend/hip/mma_hip.h" +#include "gpu_iface/gpu_runtime_compat.hpp" + +namespace { + +using namespace flashinfer::gpu_iface::debug_utils::hip; + +/// Kernel to test the result of load_amatrix_layout +__global__ void get_a_layout_fragments_kernel(half* output) { + uint32_t registers[2]; + __shared__ half lds_array[16 * 16]; + + lexicographic_init_lds_array(lds_array, 16, 16); + load_amatrix_layout<__half>(lds_array, registers, 16); + + const __half* values = reinterpret_cast<const __half*>(registers); + int offset = threadIdx.x * 4; + output[offset + 0] = values[0]; + output[offset + 1] = values[1]; + output[offset + 2] = values[2]; + output[offset + 3] = values[3]; +} + +/// Kernel to test the result of load_bmatrix_layout +__global__ void get_b_layout_fragments_kernel(half* output) { + uint32_t registers[2]; + __shared__ half lds_array[16 * 16]; + + // 1. Init LDS with 0..255 + lexicographic_init_lds_array(lds_array, 16, 16); + + // 2. Load registers using B-layout pattern + load_bmatrix_layout<__half>(lds_array, registers, 16); + + // 3. Write register contents to global memory for validation + const __half* values = reinterpret_cast<const __half*>(registers); + int offset = threadIdx.x * 4; + output[offset + 0] = values[0]; + output[offset + 1] = values[1]; + output[offset + 2] = values[2]; + output[offset + 3] = values[3]; +} + +/// Kernel to test the full B -> A transformation +__global__ void get_b_to_a_transform_fragments_kernel(half* output) { + uint32_t registers[2]; + __shared__ half lds_array[16 * 16]; + + // 1. Init and load B-layout into registers + lexicographic_init_lds_array(lds_array, 16, 16); + load_bmatrix_layout<__half>(lds_array, registers, 16); + + // 2. 
Apply the B -> A transformation + flashinfer::gpu_iface::mma_impl::hip::transpose_intra_quad_fragments(registers); + flashinfer::gpu_iface::mma_impl::hip::transpose_inter_quad_fragments(registers); + + // 3. Write final register contents to global memory + const __half* values = reinterpret_cast<const __half*>(registers); + int offset = threadIdx.x * 4; + output[offset + 0] = values[0]; + output[offset + 1] = values[1]; + output[offset + 2] = values[2]; + output[offset + 3] = values[3]; +} + +/// Kernel to test the full A -> B transformation +__global__ void get_a_to_b_transform_fragments_kernel(half* output) { + uint32_t registers[2]; + __shared__ half lds_array[16 * 16]; + + // 1. Init and load A-layout into registers + lexicographic_init_lds_array(lds_array, 16, 16); + load_amatrix_layout<__half>(lds_array, registers, 16); + + // 2. Apply the A -> B transformation + flashinfer::gpu_iface::mma_impl::hip::transpose_intra_quad_fragments(registers); + flashinfer::gpu_iface::mma_impl::hip::transpose_inter_quad_fragments(registers); + + // 3. 
Write final register contents to global memory + const __half* values = reinterpret_cast<const __half*>(registers); + int offset = threadIdx.x * 4; + output[offset + 0] = values[0]; + output[offset + 1] = values[1]; + output[offset + 2] = values[2]; + output[offset + 3] = values[3]; +} + +} // namespace + +class LayoutTransformTest : public ::testing::Test { + protected: + void SetUp() override { + // Allocate 256 * sizeof(half) for output + FI_GPU_CALL(hipMalloc(&d_output, 256 * sizeof(half))); + h_output.resize(256); + } + + void TearDown() override { FI_GPU_CALL(hipFree(d_output)); } + + half* d_output; + std::vector<half> h_output; +}; + +TEST_F(LayoutTransformTest, LoadALayoutIsCorrect) { + get_a_layout_fragments_kernel<<<1, 64>>>(d_output); + FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost)); + + // On CPU, compute the expected A-layout fragments + for (int lane_id = 0; lane_id < 64; ++lane_id) { + // A-layout: T_(16*c+r) holds row r, columns 4c to 4c+3 + int row = lane_id % 16; + int col_start = (lane_id / 16) * 4; + for (int i = 0; i < 4; ++i) { + float expected_val = (float)(row * 16 + col_start + i); + float gpu_val = __half2float(h_output[lane_id * 4 + i]); + EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i; + } + } +} + +TEST_F(LayoutTransformTest, LoadBLayoutIsCorrect) { + // Launch kernel to get B-layout fragments + get_b_layout_fragments_kernel<<<1, 64>>>(d_output); + FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost)); + + // On CPU, compute the expected fragments based on the correct CDNA layout + for (int lane_id = 0; lane_id < 64; ++lane_id) { + // Correct mapping from lane_id to block coordinates + int block_row = lane_id / 16; + int block_col = (lane_id % 16) / 4; + int thread_in_block = lane_id % 4; + + // The B-Layout means each thread holds a column fragment. + // The fragment's column index is determined by the thread's block_col. 
+ // The fragment's starting row is determined by the thread's block_row. + // The element within the fragment is determined by the thread_in_block. + // + // Example 1, T0 (lane_id=0): br=0, bc=0, tib=0. It holds column 0, elements 0-3. + // Expected: [M(0,0), M(1,0), M(2,0), M(3,0)] -> [0, 16, 32, 48] + // + // Example 2, T17 (lane_id=17): br=1, bc=0, tib=1. It holds column 1, elements 4-7. + // Expected: [M(4,1), M(5,1), M(6,1), M(7,1)] -> [65, 81, 97, 113] + for (int i = 0; i < 4; ++i) { + float expected_val = (float)((block_col * 4 + i) * 16 + (block_row * 4 + thread_in_block)); + + // Let's use the original correct logic with the corrected variable names. + int orig_matrix_col = block_col * 4 + i; + int orig_matrix_row = block_row * 4 + thread_in_block; + expected_val = (float)(orig_matrix_row * 16 + orig_matrix_col); + + // The B-layout fragment for a thread is a column segment. + // T0 gets column 0, elements 0-3 -> [0, 16, 32, 48] + // T1 gets column 1, elements 0-3 -> [1, 17, 33, 49] + // T16 gets column 0, elements 4-7 -> [64, 80, 96, 112] + // T17 gets column 1, elements 4-7 -> [65, 81, 97, 113] + int frag_col_idx = block_col * 4 + thread_in_block; + int frag_row_start = block_row * 4; + + expected_val = (frag_row_start + i) * 16 + frag_col_idx; + + float gpu_val = __half2float(h_output[lane_id * 4 + i]); + EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i; + } + } +} + +TEST_F(LayoutTransformTest, TransformAtoBIsCorrect) { + get_a_to_b_transform_fragments_kernel<<<1, 64>>>(d_output); + FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost)); + + // The expected result is the B-layout fragment, same as the LoadBLayoutIsCorrect test + for (int lane_id = 0; lane_id < 64; ++lane_id) { + int block_row = lane_id / 16; + int block_col = (lane_id % 16) / 4; + int thread_in_block = lane_id % 4; + int frag_col_idx = block_col * 4 + thread_in_block; + int frag_row_start = block_row * 4; + 
+ for (int i = 0; i < 4; ++i) { + float expected_val = (float)((frag_row_start + i) * 16 + frag_col_idx); + float gpu_val = __half2float(h_output[lane_id * 4 + i]); + EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i; + } + } +} + +TEST_F(LayoutTransformTest, TransformBtoAIsCorrect) { + // Launch kernel to get transformed fragments + get_b_to_a_transform_fragments_kernel<<<1, 64>>>(d_output); + FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost)); + + // On CPU, compute the expected A-layout fragments + for (int lane_id = 0; lane_id < 64; ++lane_id) { + // T0 gets {0,1,2,3}, T1 gets {16,17,18,19}, etc. + int row = lane_id % 16; + int col_start = (lane_id / 16) * 4; + + for (int i = 0; i < 4; ++i) { + float expected_val = (float)(row * 16 + col_start + i); + float gpu_val = __half2float(h_output[lane_id * 4 + i]); + EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i; + } + } +} diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index 56c8f9391d..fd638a75d0 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -57,7 +57,7 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) { int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma_impl::hip::load_fragment_4x4_half_registers<__half>(b_reg, &B[b_idx]); + flashinfer::gpu_iface::mma::load_quad_transposed_fragment<__half>(b_reg, &B[b_idx]); flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg); for (int i = 0; i < 4; ++i) {