Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using namespace cute;

typedef uint32_t __nv_fp4x8_storage_t;
typedef uint32_t __nv_bf16x2_storage_t;
typedef uint32_t __nv_int4x8_storage_t;
typedef uint64_t __nv_fp8x8_storage_t;
typedef cutlass::uint128_t __nv_bf16x8_storage_t;

constexpr int int4_group_size = 128;
Expand All @@ -50,53 +52,97 @@ inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code)
return res;
}

__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index)
__constant__ static __nv_fp8x4_storage_t HIGH_E4M3s_LUT_[2] = {0x03020100U, 0x03020100U};
__constant__ static __nv_fp8x4_storage_t LOW_E4M3s_LUT_[2] = {0xFFFEFC00U, 0xFFFEFC00U};

__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_fp4_to_bf16(unsigned const index)
{
const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654
const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210

auto lane_id = threadIdx.x & 0x1;
__nv_fp8x4_storage_t h4b_lut = HIGH_E4M3s_LUT_[lane_id];
__nv_fp8x4_storage_t l4b_lut = LOW_E4M3s_LUT_[lane_id];

__nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index);

return lut_res;
}

__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8)
__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(
const __nv_fp4x8_storage_t fp4x8)
{
__nv_bf16x8_storage_t bf16x8_raw = {0, 0};
__nv_bf16x8_storage_t bf16x8_raw;
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);

unsigned zero_padding = 0x00000000U;
__nv_fp8x4_storage_t h_fp8x4_0to1_bits = (fp4x8 & 0xC0C0C0C0U) >> 6; // 7632
__nv_fp8x4_storage_t l_fp8x4_0to1_bits = (fp4x8 & 0x0C0C0C0CU) >> 2; // 5410

unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);

__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210
__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_fp4_to_bf16(h4b_em_fp4x4); // 7564
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_fp4_to_bf16(l4b_em_fp4x4); // 3120

bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0
bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2
bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4
bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6
bf16x2_raw[0] = prmt(l_fp8x4_0to1_bits, l4b_2to9_bits, 0x5240U) << 6U; // 1 0
bf16x2_raw[1] = prmt(h_fp8x4_0to1_bits, l4b_2to9_bits, 0x5341U) << 6U; // 3 2

__nv_bf16x2_storage_t bf16x2_0to1_bits;
bf16x2_raw[2] = prmt(l_fp8x4_0to1_bits, h4b_2to9_bits, 0x7260U) << 6U; // 5 4
bf16x2_raw[3] = prmt(h_fp8x4_0to1_bits, h4b_2to9_bits, 0x7361U) << 6U; // 7 6

__nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1
__nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0
return bf16x8_raw;
}

bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0
bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits;
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2
bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits;
// [ 0, 1, 2, 3] encoded as FP8
__constant__ static uint32_t POS_E4M3s_REG1_[2] = {0x44403800, 0x44403800};
// [ 4, 5, 6, 7] encoded as FP8
__constant__ static uint32_t POS_E4M3s_REG2_[2] = {0x4E4C4A48, 0x4E4C4A48};
// [-8, -7, -6, -5] encoded as FP8
__constant__ static uint32_t NEG_E4M3s_REG1_[2] = {0xCACCCED0, 0xCACCCED0};
// [-4, -3, -2, -1] encoded as FP8
__constant__ static uint32_t NEG_E4M3s_REG2_[2] = {0xB8C0C4C8, 0xB8C0C4C8};

h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5
l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4
__device__ __inline__ __nv_fp8x8_storage_t psx_cvt_lut_prmt_int4x8_to_fp8x8(const __nv_int4x8_storage_t int4x8)
{
__nv_fp8x8_storage_t fp8x8_raw;
__nv_fp8x4_storage_t* fp8x4_raw = reinterpret_cast<__nv_fp8x4_storage_t*>(&fp8x8_raw);

bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4
bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits;
bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6
bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits;
// View the input as reg
uint32_t reg = reinterpret_cast<uint32_t const&>(int4x8);

return bf16x8_raw;
// Determines if to get from the signed or unsigned candidates
uint32_t sign = (reg & 0x88888888) >> 1;

// Ignore sign bit when indexing into LUT
uint32_t lut_idx = (reg & 0x77777777);

// Signed is OR'd with 0x32103210 to find the correct value in the LUT
const uint32_t final_prmt_base = 0x32103210;

auto lane_id = threadIdx.x & 0x1;
uint32_t POS_E4M3s_REG1 = POS_E4M3s_REG1_[lane_id];
uint32_t POS_E4M3s_REG2 = POS_E4M3s_REG2_[lane_id];
uint32_t NEG_E4M3s_REG1 = NEG_E4M3s_REG1_[lane_id];
uint32_t NEG_E4M3s_REG2 = NEG_E4M3s_REG2_[lane_id];

asm volatile(
"{\n"
" .reg .b32 pos_f8s, neg_f8s;\n"
" .reg .b32 lut1, sign1, prmt0, prmt1;\n"
" or.b32 prmt0, %4, %3;\n"
" prmt.b32 pos_f8s, %5, %6, %2;\n"
" prmt.b32 neg_f8s, %7, %8, %2;\n"
" prmt.b32 %0, pos_f8s, neg_f8s, prmt0;\n"
" shr.u32 lut1, %2, 16;\n"
" shr.u32 sign1, %3, 16;\n"
" or.b32 prmt1, %4, sign1;\n"
" prmt.b32 pos_f8s, %5, %6, lut1;\n"
" prmt.b32 neg_f8s, %7, %8, lut1;\n"
" prmt.b32 %1, pos_f8s, neg_f8s, prmt1;\n"
"}\n"
: "=r"(fp8x4_raw[0]), "=r"(fp8x4_raw[1])
: "r"(lut_idx), "r"(sign), "r"(final_prmt_base), "r"(POS_E4M3s_REG1), "r"(POS_E4M3s_REG2), "r"(NEG_E4M3s_REG1),
"r"(NEG_E4M3s_REG2));

return fp8x8_raw;
}

template <class Collective>
Expand All @@ -119,6 +165,7 @@ struct MixedGroupedGemmInputUtils
static constexpr auto ModeHasScales = Collective::ModeHasScales;
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable;
static constexpr auto UseInt4ToFP8LookupTable = Collective::UseInt4ToFP8LookupTable;

public:
static constexpr auto elements_per_smem_scale()
Expand Down Expand Up @@ -205,14 +252,70 @@ struct MixedGroupedGemmInputUtils
}
}

/// Utilities to copy A from smem to RF
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView>
CUTLASS_DEVICE static void copy_tensors_A(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA,
TensorACopyView& tCrA_copy_view, int k_block, int read_stage)
{

if (k_block < size<2>(tCsA.shape()))
{
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
}
}

/// Utilities to copy Scales for A from smem to RF
template <class... Ts, class... Us>
CUTLASS_DEVICE static void copy_tensors_SFA(cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views, int k_block, int read_stage)
{

// We are starting a new k-tile so copy the scale
if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
{
// nothing to do
}
else if constexpr (ModeHasScales)
{
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
auto tCsS = cute::get<0>(partitioned_mma_extra_info);

copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
// Nothing extra to do
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
}
else
{
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
else
{
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}

/// Utilities to copy A and extra inputs from smem to RF
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
CUTLASS_DEVICE static void copy_tensors_MK(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA,
TensorACopyView& tCrA_copy_view, cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views, int k_block, int read_stage)
{

copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
if (k_block < size<2>(tCsA.shape()))
{
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
}

if (k_block == 0)
{
Expand Down Expand Up @@ -312,7 +415,6 @@ struct MixedGroupedGemmInputUtils
}
}

// The core converter uses a lookup table to converts i4 -> 8 bit value.
template <class EngineIn, class LayoutIn, class EngineOut,
class LayoutOut>
CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries
Expand All @@ -330,7 +432,27 @@ struct MixedGroupedGemmInputUtils
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);

dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(src_);
}

template <class EngineIn, class LayoutIn, class EngineOut,
class LayoutOut>
CUTLASS_DEVICE static void int4tofp8_lookup_table_convert( // Accept mutable temporaries
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>&& dst)
{
int4tofp8_lookup_table_convert(src, dst);
}

template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut>
CUTLASS_DEVICE static void int4tofp8_lookup_table_convert(
Tensor<EngineIn, LayoutIn> const& src, Tensor<EngineOut, LayoutOut>& dst)
{

// View the input as reg
auto&& src_ = cute::recast<__nv_int4x8_storage_t>(src)(0);
auto&& dst_ = cute::recast<__nv_fp8x8_storage_t>(dst)(0);

dst_ = psx_cvt_lut_prmt_int4x8_to_fp8x8(src_);
}

/// Utilities to dequantize A.
Expand Down Expand Up @@ -535,6 +657,10 @@ struct MixedGroupedGemmInputUtils
{
fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i));
}
else if constexpr (UseInt4ToFP8LookupTable)
{
int4tofp8_lookup_table_convert(src_vm(_, i), dst_vm(_, i));
}
else
{
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
Expand Down
Loading