Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ __device__ void load_bmatrix_layout(T* arr, uint32_t* R, uint32_t dimY) {
static_assert(std::is_same_v<T, __half>, "Only supported for __half types");
const int lane_id = threadIdx.x % 64;
int b_idx = ((lane_id % 4) + 4 * (lane_id / 16)) * dimY + ((lane_id % 16) / 4) * 4;
mma_impl::hip::load_fragment_4x4_half_registers<__half>(R, &arr[b_idx]);
mma_impl::hip::load_quad_transposed_fragment<__half>(R, &arr[b_idx]);
}

/// @brief Prints the four `half` values held in a thread's registers.
Expand Down
64 changes: 60 additions & 4 deletions libflashinfer/include/gpu_iface/backend/hip/mma_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,26 @@ namespace hip {

#define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x)

__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) {
/// @brief Transposes a 4x4 matrix of `half` values held across a quad of 4 threads.
/// @details This function operates on a group of 4 consecutive threads (a quad). It assumes
/// each thread holds 4 `half` values, which together form a 4x4 matrix where each
/// thread holds one row. The function permutes these values using a series of
/// `__shfl_xor` operations so that each thread ends up holding one column of the
/// original 4x4 matrix.
///
/// Visual Representation:
/// If `[a,b,c,d]` are the 4 `half` values in Thread 0's registers:
///
/// Before: After:
/// Thread 0: [a, b, c, d] Thread 0: [a, e, i, m]
/// Thread 1: [e, f, g, h] ---> Thread 1: [b, f, j, n]
/// Thread 2: [i, j, k, l] Thread 2: [c, g, k, o]
/// Thread 3: [m, n, o, p] Thread 3: [d, h, l, p]
///
/// @note This function can be combined with `transpose_inter_quad_fragments` to perform a
/// full 16x16 in-register matrix transpose. This function handles the transposition
/// *within* each 4x4 data block.
__device__ __forceinline__ void transpose_intra_quad_fragments(uint32_t* R) {
// Calculate lane within 4-thread group
uint32_t lane_id = threadIdx.x % 64;
uint32_t lane_in_group = lane_id % 4;
Expand Down Expand Up @@ -55,6 +74,43 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) {
R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount);
}

/// @brief Permutes matrix fragments between thread quads in a wavefront to perform a block-wise
/// transpose.
/// @details This function treats the 64-thread wavefront as a 4x4 grid of thread quads.
/// Each quad (4 consecutive threads) is considered to hold a 4x4 data fragment.
/// The function transposes this 4x4 grid of fragments by swapping the register
/// contents of threads in off-diagonal quads.
///
/// Visual Representation:
/// If B(r,c) is the 4x4 data fragment held by the quad at block-row 'r' and block-col 'c':
///
/// Before: After:
/// +--------+--------+--------+--------+ +--------+--------+--------+--------+
/// | B(0,0) | B(0,1) | B(0,2) | B(0,3) | | B(0,0) | B(1,0) | B(2,0) | B(3,0) |
/// +--------+--------+--------+--------+ +--------+--------+--------+--------+
/// | B(1,0) | B(1,1) | B(1,2) | B(1,3) | ---> | B(0,1) | B(1,1) | B(2,1) | B(3,1) |
/// +--------+--------+--------+--------+ +--------+--------+--------+--------+
/// | B(2,0) | B(2,1) | B(2,2) | B(2,3) | | B(0,2) | B(1,2) | B(2,2) | B(3,2) |
/// +--------+--------+--------+--------+ +--------+--------+--------+--------+
/// | B(3,0) | B(3,1) | B(3,2) | B(3,3) | | B(0,3) | B(1,3) | B(2,3) | B(3,3) |
/// +--------+--------+--------+--------+ +--------+--------+--------+--------+
///
/// @note This function can be combined with `transpose_intra_quad_fragments` (which transposes
/// the data *within* each fragment) to perform a full 16x16 in-register matrix transpose.
__device__ __forceinline__ void transpose_inter_quad_fragments(uint32_t* R) {
  const uint32_t lane = threadIdx.x % 64;

  // Decompose the lane index according to the 4x4 grid of thread quads:
  //   lane = block_col * 16 + block_row * 4 + lane_in_quad
  const uint32_t lane_in_quad = lane % 4;
  const uint32_t grid_row = (lane % 16) / 4;
  const uint32_t grid_col = lane / 16;

  // The exchange partner occupies the mirrored (row, col) position in the
  // quad grid; for quads on the diagonal the partner is the lane itself.
  const uint32_t partner = grid_row * 16 + grid_col * 4 + lane_in_quad;
  const uint32_t swap_mask = lane ^ partner;

  // Swap both registers with the partner thread (no-op when swap_mask == 0).
  for (int reg = 0; reg < 2; ++reg) {
    R[reg] = __shfl_xor(R[reg], swap_mask, 64);
  }
}

// Single unified load function for all fragment types
/// @param R [in] pointer to the register file to load the fragment into
/// @param smem_ptr [in] pointer to the shared memory to load the fragment from
Expand Down Expand Up @@ -116,10 +172,10 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
/// T2 : c g k o
/// T3 : d h l p
template <typename T>
__device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) {
__device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const T* smem_ptr) {
static_assert(std::is_same_v<T, __half>, "Only half type is supported");
load_fragment(R, smem_ptr);
transpose_4x4_half_registers(R);
transpose_intra_quad_fragments(R);
}

// TODO: Verify correct matrix multiplication order for rowsum on CDNA3
Expand All @@ -135,7 +191,7 @@ __device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, co
template <typename DType>
__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) {
static_assert(sizeof(DType) == 2, "DType must be 16-bit type");
transpose_4x4_half_registers(reinterpret_cast<uint32_t*>(s_frag));
transpose_intra_quad_fragments(reinterpret_cast<uint32_t*>(s_frag));
f16x4 a = reinterpret_cast<const f16x4*>(s_frag)[0];
f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)};
f32x4 c = {d[0], d[1], d[2], d[3]};
Expand Down
9 changes: 4 additions & 5 deletions libflashinfer/include/gpu_iface/mma_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) {
mma_detail::load_fragment<T>(R, smem_ptr);
}

#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__)
#if defined(PLATFORM_HIP_DEVICE)
template <typename T>
__device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t* R,
const T* smem_ptr) {
__device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const T* smem_ptr) {
static_assert(std::is_same<T, __half>::value,
"Only __half is supported for the 4x4 register transpose");
mma_detail::load_fragment_4x4_half_registers<__half>(R, smem_ptr);
"Only __half is supported for load_quad_transposed_fragment");
mma_detail::load_quad_transposed_fragment<T>(R, smem_ptr);
}
#endif

Expand Down
212 changes: 212 additions & 0 deletions libflashinfer/tests/hip/test_layout_transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <stdio.h>

#include "gpu_iface/backend/hip/mma_debug_utils_hip.h"
#include "gpu_iface/backend/hip/mma_hip.h"
#include "gpu_iface/gpu_runtime_compat.hpp"

namespace {

using namespace flashinfer::gpu_iface::debug_utils::hip;

/// Kernel to test the result of load_amatrix_layout
/// Kernel to test the result of load_amatrix_layout.
/// Expects a single 64-thread block (one wavefront); each lane writes its four
/// register halves to output[lane * 4 .. lane * 4 + 3].
__global__ void get_a_layout_fragments_kernel(half* output) {
  uint32_t frag[2];
  __shared__ half lds_array[16 * 16];

  // Fill LDS with 0..255, then pull this lane's A-layout fragment into registers.
  // NOTE(review): assumes lexicographic_init_lds_array leaves LDS visible to the
  // single wavefront before the load — confirm it synchronizes if block size grows.
  lexicographic_init_lds_array(lds_array, 16, 16);
  load_amatrix_layout<__half>(lds_array, frag, 16);

  // Dump the four half values per lane, contiguously, for host-side validation.
  const __half* vals = reinterpret_cast<const __half*>(frag);
  const int base = threadIdx.x * 4;
  for (int i = 0; i < 4; ++i) {
    output[base + i] = vals[i];
  }
}

/// Kernel to test the result of load_bmatrix_layout
/// Kernel to test the result of load_bmatrix_layout.
/// Expects a single 64-thread block (one wavefront); each lane writes its four
/// register halves to output[lane * 4 .. lane * 4 + 3].
__global__ void get_b_layout_fragments_kernel(half* output) {
  uint32_t frag[2];
  __shared__ half lds_array[16 * 16];

  // Fill LDS with 0..255, then load this lane's registers via the B-layout pattern.
  lexicographic_init_lds_array(lds_array, 16, 16);
  load_bmatrix_layout<__half>(lds_array, frag, 16);

  // Dump the four half values per lane, contiguously, for host-side validation.
  const __half* vals = reinterpret_cast<const __half*>(frag);
  const int base = threadIdx.x * 4;
  for (int i = 0; i < 4; ++i) {
    output[base + i] = vals[i];
  }
}

/// Kernel to test the full B -> A transformation
/// Kernel to test the full B -> A transformation.
/// Loads a B-layout fragment, applies the intra-quad then inter-quad transposes,
/// and writes each lane's four resulting halves to output[lane * 4 ..].
__global__ void get_b_to_a_transform_fragments_kernel(half* output) {
  uint32_t frag[2];
  __shared__ half lds_array[16 * 16];

  // Init LDS with 0..255 and load the B-layout fragment into registers.
  lexicographic_init_lds_array(lds_array, 16, 16);
  load_bmatrix_layout<__half>(lds_array, frag, 16);

  // The two-step transform: transpose within each 4x4 quad fragment, then
  // permute fragments across quads — together a full 16x16 transpose.
  flashinfer::gpu_iface::mma_impl::hip::transpose_intra_quad_fragments(frag);
  flashinfer::gpu_iface::mma_impl::hip::transpose_inter_quad_fragments(frag);

  // Dump the transformed registers for host-side validation.
  const __half* vals = reinterpret_cast<const __half*>(frag);
  const int base = threadIdx.x * 4;
  for (int i = 0; i < 4; ++i) {
    output[base + i] = vals[i];
  }
}

/// Kernel to test the full A -> B transformation
/// Kernel to test the full A -> B transformation.
/// Loads an A-layout fragment, applies the intra-quad then inter-quad transposes,
/// and writes each lane's four resulting halves to output[lane * 4 ..].
__global__ void get_a_to_b_transform_fragments_kernel(half* output) {
  uint32_t frag[2];
  __shared__ half lds_array[16 * 16];

  // Init LDS with 0..255 and load the A-layout fragment into registers.
  lexicographic_init_lds_array(lds_array, 16, 16);
  load_amatrix_layout<__half>(lds_array, frag, 16);

  // The two-step transform: transpose within each 4x4 quad fragment, then
  // permute fragments across quads — together a full 16x16 transpose.
  flashinfer::gpu_iface::mma_impl::hip::transpose_intra_quad_fragments(frag);
  flashinfer::gpu_iface::mma_impl::hip::transpose_inter_quad_fragments(frag);

  // Dump the transformed registers for host-side validation.
  const __half* vals = reinterpret_cast<const __half*>(frag);
  const int base = threadIdx.x * 4;
  for (int i = 0; i < 4; ++i) {
    output[base + i] = vals[i];
  }
}

} // namespace

/// Fixture providing a device output buffer (64 lanes x 4 halves = 256 elements)
/// and a matching host-side vector for result validation.
class LayoutTransformTest : public ::testing::Test {
 protected:
  void SetUp() override {
    h_output.resize(256);
    // Device buffer receives one half per register element per lane.
    FI_GPU_CALL(hipMalloc(&d_output, 256 * sizeof(half)));
  }

  void TearDown() override { FI_GPU_CALL(hipFree(d_output)); }

  half* d_output;               // device-side result buffer
  std::vector<half> h_output;   // host-side copy for assertions
};

TEST_F(LayoutTransformTest, LoadALayoutIsCorrect) {
  get_a_layout_fragments_kernel<<<1, 64>>>(d_output);
  // Kernel launches return no status; surface launch-configuration errors here.
  FI_GPU_CALL(hipGetLastError());
  FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost));

  // On CPU, compute the expected A-layout fragments.
  for (int lane_id = 0; lane_id < 64; ++lane_id) {
    // A-layout: T_(16*c+r) holds row r, columns 4c to 4c+3
    int row = lane_id % 16;
    int col_start = (lane_id / 16) * 4;
    for (int i = 0; i < 4; ++i) {
      float expected_val = (float)(row * 16 + col_start + i);
      float gpu_val = __half2float(h_output[lane_id * 4 + i]);
      EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i;
    }
  }
}

TEST_F(LayoutTransformTest, LoadBLayoutIsCorrect) {
  // Launch kernel to get B-layout fragments.
  get_b_layout_fragments_kernel<<<1, 64>>>(d_output);
  // Kernel launches return no status; surface launch-configuration errors here.
  FI_GPU_CALL(hipGetLastError());
  FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost));

  // On CPU, compute the expected fragments for the CDNA B layout. The wavefront
  // is a 4x4 grid of thread quads and each thread holds a 4-element column
  // segment of the 16x16 matrix:
  //   T0  (br=0, bc=0, tib=0): column 0, rows 0-3 -> [0, 16, 32, 48]
  //   T1  (br=0, bc=0, tib=1): column 1, rows 0-3 -> [1, 17, 33, 49]
  //   T16 (br=1, bc=0, tib=0): column 0, rows 4-7 -> [64, 80, 96, 112]
  //   T17 (br=1, bc=0, tib=1): column 1, rows 4-7 -> [65, 81, 97, 113]
  for (int lane_id = 0; lane_id < 64; ++lane_id) {
    // Mapping from lane_id to block coordinates.
    int block_row = lane_id / 16;
    int block_col = (lane_id % 16) / 4;
    int thread_in_block = lane_id % 4;

    // Column held by this thread and the first row of its 4-element segment.
    int frag_col_idx = block_col * 4 + thread_in_block;
    int frag_row_start = block_row * 4;

    for (int i = 0; i < 4; ++i) {
      float expected_val = (float)((frag_row_start + i) * 16 + frag_col_idx);
      float gpu_val = __half2float(h_output[lane_id * 4 + i]);
      EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i;
    }
  }
}

TEST_F(LayoutTransformTest, TransformAtoBIsCorrect) {
  get_a_to_b_transform_fragments_kernel<<<1, 64>>>(d_output);
  // Kernel launches return no status; surface launch-configuration errors here.
  FI_GPU_CALL(hipGetLastError());
  FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost));

  // The expected result is the B-layout fragment, same as the LoadBLayoutIsCorrect test.
  for (int lane_id = 0; lane_id < 64; ++lane_id) {
    int block_row = lane_id / 16;
    int block_col = (lane_id % 16) / 4;
    int thread_in_block = lane_id % 4;
    int frag_col_idx = block_col * 4 + thread_in_block;
    int frag_row_start = block_row * 4;

    for (int i = 0; i < 4; ++i) {
      float expected_val = (float)((frag_row_start + i) * 16 + frag_col_idx);
      float gpu_val = __half2float(h_output[lane_id * 4 + i]);
      EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i;
    }
  }
}

TEST_F(LayoutTransformTest, TransformBtoAIsCorrect) {
  // Launch kernel to get transformed fragments.
  get_b_to_a_transform_fragments_kernel<<<1, 64>>>(d_output);
  // Kernel launches return no status; surface launch-configuration errors here.
  FI_GPU_CALL(hipGetLastError());
  FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, 256 * sizeof(half), hipMemcpyDeviceToHost));

  // On CPU, compute the expected A-layout fragments.
  for (int lane_id = 0; lane_id < 64; ++lane_id) {
    // T0 gets {0,1,2,3}, T1 gets {16,17,18,19}, etc.
    int row = lane_id % 16;
    int col_start = (lane_id / 16) * 4;

    for (int i = 0; i < 4; ++i) {
      float expected_val = (float)(row * 16 + col_start + i);
      float gpu_val = __half2float(h_output[lane_id * 4 + i]);
      EXPECT_EQ(gpu_val, expected_val) << "Mismatch at lane " << lane_id << ", element " << i;
    }
  }
}
2 changes: 1 addition & 1 deletion libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) {
int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4;

flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]);
flashinfer::gpu_iface::mma_impl::hip::load_fragment_4x4_half_registers<__half>(b_reg, &B[b_idx]);
flashinfer::gpu_iface::mma::load_quad_transposed_fragment<__half>(b_reg, &B[b_idx]);
flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg);

for (int i = 0; i < 4; ++i) {
Expand Down