diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 465241546d..33fca4fd8f 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -991,12 +991,12 @@ __device__ auto quantizePackedFPXValue( if constexpr (is_fp8) { return [](PackedVec& vec, float /* ignored */, uint8_t* SFout) -> uint64_t { static_assert(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == VecSize); - return cvt_warp_fp16_to_mxfp8(vec, SFout); + return cvt_warp_fp16_to_mxfp8(vec, SFout); }; } else { return (scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4) - ? &cvt_warp_fp16_to_fp4 - : &cvt_warp_fp16_to_fp4; + ? &cvt_warp_fp16_to_fp4 + : &cvt_warp_fp16_to_fp4; } }(); diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index ca1bb31acd..92598e5bba 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#include +#include #include #include "tensorrt_llm/common/assert.h" @@ -93,6 +95,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl, cudaStream_t stream) { // Fixed SF_VEC_SIZE as 32 + // TODO: TMA quantization for MXFP8 is not supported yet because of SF_VEC_SIZE = 32. static constexpr int SF_VEC_SIZE = 32; // Grid, Block size. @@ -178,9 +181,136 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3); #endif #endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TMA tensor map creation helpers + +template +CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3], + uint64_t stride_in_bytes[2], uint32_t smem_dim[3], + CUtensorMapSwizzle swizzle_type) { + CUtensorMap tensor_map{}; + constexpr uint32_t rank = 3; + uint32_t elem_strides[rank] = {1, 1, 1}; + + // Get pointer to cuTensorMapEncodeTiled + cudaDriverEntryPointQueryResult driver_status; + void* cuTensorMapEncodeTiled_ptr = nullptr; + +#if CUDA_VERSION >= 12050 + cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, + cudaEnableDefault, &driver_status); +#else + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, + &driver_status); +#endif + + if (driver_status != cudaDriverEntryPointSuccess) { + TLLM_CHECK_WITH_INFO(false, "Failed to get cuTensorMapEncodeTiled entry point"); + } + + auto encode_func = + reinterpret_cast(cuTensorMapEncodeTiled_ptr); + + CUtensorMapDataType data_type; + if constexpr (std::is_same_v) { + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if constexpr (std::is_same_v) { + data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if constexpr (std::is_same_v) { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + + CUresult result = + encode_func(&tensor_map, data_type, rank, global_address, gmem_dim, stride_in_bytes, smem_dim, + elem_strides, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA); + TLLM_CHECK_WITH_INFO(result == CUDA_SUCCESS, "Failed to encode TMA tensor map"); + return tensor_map; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // FP4/MXFP8 Quantization +// Helper function to launch TMA quantization kernel +template +void launchFP4QuantizationTma(int b, int m, int n, T const* input, float const* SFScale, + int64_t* output, int32_t* SFOuput, bool useUE8M0, + QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl, + cudaStream_t stream) { + using Traits = TmaKernelTraits; + constexpr int TMA_ROW_TILE = Traits::TMA_ROW_TILE; + constexpr int TMA_COL_TILE = Traits::TMA_COL_TILE; + constexpr int NUM_CONSUMER_WARPS = 8; + + // Compute effective rows for swizzled layouts + int effectiveRows = computeEffectiveRows(m, layout); + + // Grid and block configuration for TMA kernel + // TMA kernel uses 288 threads: 1 producer warp + 8 consumer warps + dim3 block(288); + // Each block handles TMA_ROW_TILE rows + int numRowTiles = (effectiveRows + TMA_ROW_TILE - 1) / TMA_ROW_TILE; + dim3 grid(std::min(numRowTiles, multiProcessorCount * 2)); + + // Dynamic shared memory size + size_t smem_size = get_tma_smem_size(); + + // Create 3D TMA tensor map descriptor + // The TMA kernel loads a box of [TMA_COL_TILE, TMA_ROW_TILE, NUM_CONSUMER_WARPS] elements per TMA + // call Global tensor is treated as [TMA_COL_TILE, B*M, num_tiles] where num_tiles = N / + // TMA_COL_TILE. We use b * m (not b * effectiveRows) because batches are stored contiguously + // without padding between them. + int num_col_tiles = (n + TMA_COL_TILE - 1) / TMA_COL_TILE; + uint64_t gmem_dim[3] = { + static_cast(TMA_COL_TILE), // Elements per tile (contiguous in memory) + static_cast(b * m), // Total rows across all batches + static_cast(num_col_tiles) // Number of column tiles + }; + uint64_t stride_in_bytes[2] = { + static_cast(n * sizeof(T)), // Stride between rows (in bytes) + static_cast(TMA_COL_TILE * sizeof(T)) // Stride between tiles (in bytes) + }; + uint32_t smem_dim[3] = { + static_cast(TMA_COL_TILE), // Elements loaded per tile + static_cast(TMA_ROW_TILE), // Rows loaded per TMA call + static_cast(NUM_CONSUMER_WARPS) // Number of tiles loaded (for 8 consumer warps) + }; + + // CUtensorMap must be 64-byte aligned + // Use SWIZZLE_128B for half/bf16 (2-byte types), SWIZZLE_NONE for FP8 (1-byte types) + constexpr CUtensorMapSwizzle swizzle_type = + (std::is_same_v || std::is_same_v) + ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B + : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + alignas(64) CUtensorMap tensor_map = make_3d_tma_copy_desc( + const_cast(input), gmem_dim, stride_in_bytes, smem_dim, swizzle_type); + + // Select and launch the TMA kernel + auto* kernel_instance = + useUE8M0 ? &quantize_with_block_size_tma + : &quantize_with_block_size_tma; + + // Set max dynamic shared memory for the kernel (required for > 48KB) + cudaFuncSetAttribute(kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaLaunchConfig_t config; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, + reinterpret_cast(output), reinterpret_cast(SFOuput), + layout, tensor_map); +} + template void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, @@ -188,6 +318,18 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS cudaStream_t stream) { #ifdef ENABLE_FP8 if constexpr (std::is_same_v) { + // Use TMA kernel for large m (high throughput mode) + // TODO: fix the issue when n is not a multiple of NUM_CONSUMER_WARPS * TMA_COL_TILE + constexpr int TMA_COL_CHUNK = 8 * 64; // NUM_CONSUMER_WARPS * TMA_COL_TILE + if constexpr (SF_VEC_SIZE == 16) { + if (SF_VEC_SIZE == 16 && m >= 1024 && n % TMA_COL_CHUNK == 0) { + launchFP4QuantizationTma( + b, m, n, input, SFScale, output, SFOuput, useUE8M0, layout, multiProcessorCount, + enable_pdl, stream); + return; + } + } + // Original non-TMA path for small m or SF_VEC_SIZE != 16 // Grid, Block size. // Each thread converts 16 values. dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); @@ -205,10 +347,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS kernel_instance<<>>(b, m, n, n, input, SFScale, reinterpret_cast(output), reinterpret_cast(SFOuput), layout); - } else #endif { + // Use TMA kernel for large m (high throughput mode) + // TODO: fix the issue when n is not a multiple of NUM_CONSUMER_WARPS * TMA_COL_TILE + constexpr int TMA_COL_CHUNK = 8 * 64; // NUM_CONSUMER_WARPS * TMA_COL_TILE + if constexpr (SF_VEC_SIZE == 16) { + if (SF_VEC_SIZE == 16 && m >= 1024 && n % TMA_COL_CHUNK == 0) { + launchFP4QuantizationTma( + b, m, n, input, SFScale, output, SFOuput, useUE8M0, layout, multiProcessorCount, + enable_pdl, stream); + return; + } + } + // Original non-TMA path for small m or SF_VEC_SIZE != 16 // Grid, Block size. // Each thread converts 8 values. dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index 7abf2eb631..72ba3a812d 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -14,16 +14,21 @@ * limitations under the License. */ +#include #include +#include + #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaTypeUtils.cuh" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/quantTypeUtils.cuh" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/kernels/quantization_utils.cuh" using namespace tensorrt_llm::common; +using Barrier = cutlass::arch::ClusterTransactionBarrier; namespace tensorrt_llm { namespace kernels { @@ -88,88 +93,6 @@ __global__ static void quantizedKernel(char4* dst, __nv_bfloat162 const* src, } #endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct DstVec { - static_assert("not implemented."); -}; - -template <> -struct DstVec { - using Type = uint32_t; -}; - -template <> -struct DstVec { - using Type = uint2; -}; - -#ifdef ENABLE_BF16 - -template <> -struct DstVec<__nv_bfloat162, 4> { - using Type = uint2; -}; - -#endif // ENABLE_BF16 - -template -struct DstVec { - static_assert(sizeof(T) == 4, "not implemented."); - using Type = uint32_t; -}; - -template -struct DstVec { - static_assert(sizeof(T) == 2, "not implemented."); - using Type = uint2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Helper function of getting the absMax of all elements in the vector after clamping. -// Pack two elements in order to use possible hmax2 instructions. -template -inline __device__ void clampAndAbsMax(T& localMax, uint4& vec, T const clampMin, T const clampMax) { - static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T); - -#pragma unroll - for (int i = 0; i < NUM_ELTS; ++i) { - T& val = reinterpret_cast(&vec)[i]; - val = cuda_clamp(val, clampMin, clampMax); - localMax = cuda_max(localMax, cuda_abs(val)); - } -} - -// Helper function of quantizing the vector and storing it to global memory. -// Pack two elements in order to use fast convert instructions. -template -inline __device__ void quantizeAndStore(QuantT* dstPtr, uint4 vec, T const clampMin, - T const clampMax, float const scaleOrigQuant) { - static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T); - - using DstVecType = typename DstVec::Type; - DstVecType dstVec; -#pragma unroll - for (int i = 0; i < NUM_ELTS; ++i) { - T val = reinterpret_cast(&vec)[i]; - // Values loaded from smem has already been clamped. - if constexpr (!USE_SMEM) { - val = cuda_clamp(val, clampMin, clampMax); - } - float2 val2 = cuda_cast(val); - val2.x *= scaleOrigQuant; - val2.y *= scaleOrigQuant; - QuantT quantVal = cuda_cast(val2); - reinterpret_cast(&dstVec)[i] = quantVal; - } - // Store to destination buffer. - *reinterpret_cast(dstPtr) = dstVec; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols, float const* clampPtr, float* scalePtr, @@ -253,7 +176,7 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const nu } //////////////////////////////////////////////////////////////////////////////////////////////////// -// FP4/MXFP8 Quantization +// FP4/MXFP8 Quantization Constants constexpr int CVT_FP4_ELTS_PER_THREAD = 8; constexpr int CVT_FP4_SF_VEC_SIZE = 16; @@ -261,482 +184,8 @@ constexpr int CVT_ELTS_PER_THREAD = 8; constexpr int CVT_FP4_THREADS_PER_WARP = 32; constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16; -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]), - "f"(array[6]), "f"(array[7])); - return val; -#else - // static_assert(false, "not supported."); - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), - "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - // static_assert(false, "not supported."); - return 0; -#endif -} - -// Convert 8 float2 values into 16 e2m1 values (represented as one uint64_t). -inline __device__ uint64_t fp32_vec_to_e2m1(float2 (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint64_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - ".reg .b8 byte4;\n" - ".reg .b8 byte5;\n" - ".reg .b8 byte6;\n" - ".reg .b8 byte7;\n" - ".reg .b32 val0;\n" - ".reg .b32 val1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte4, %10, %9;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte5, %12, %11;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte6, %14, %13;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte7, %16, %15;\n" - "mov.b32 val0, {byte0, byte1, byte2, byte3};\n" - "mov.b32 val1, {byte4, byte5, byte6, byte7};\n" - "mov.b64 %0, {val0, val1};\n" - "}" - : "=l"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), - "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), "f"(array[4].x), "f"(array[4].y), - "f"(array[5].x), "f"(array[5].y), "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), - "f"(array[7].y)); - return val; -#else - // static_assert(false, "not supported."); - return 0; -#endif -} - -// Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t). -inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) { - union { - uint64_t val; - __nv_fp8x2_e4m3 elts[4]; - } u; - - static_assert(sizeof(u.val) == sizeof(u.elts), - "Expected to alias uint64_t and __nv_fp8x2_e4m3[4]"); - - u.elts[0] = __nv_fp8x2_e4m3(array[0]); - u.elts[1] = __nv_fp8x2_e4m3(array[1]); - u.elts[2] = __nv_fp8x2_e4m3(array[2]); - u.elts[3] = __nv_fp8x2_e4m3(array[3]); - return u.val; -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -__device__ __forceinline__ float exp2f_rcp(uint8_t exp) { - constexpr uint32_t FP32_EXPONENT_BIAS = 127; - return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(exp)); -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; - static_assert(sizeof(elts) == sizeof(Type) * CVT_ELTS_PER_THREAD, - "Vector size should match the number of elements per thread."); -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; - static_assert(sizeof(elts) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD, - "Vector size should match the number of elements per thread."); -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = cuda_abs(vec.elts[0]); - -// Local maximum value. -#pragma unroll - for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { - localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); - } - - constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; - // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - if constexpr (CVT_NUM_THREADS_PER_SF == 4) { - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); - } - // Get the final absolute maximum values. - float vecMax = float(cuda_max(localMax.x, localMax.y)); - - // 8 bits representation of the SF. - uint8_t fp8SFVal; - float outputScale; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - __nv_fp8_e8m0 tmp; - // Scale the max value to the range of E2m1. - vecMax *= reciprocal_approximate_ftz(6.0f); - tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf); - - fp8SFVal = tmp.__x; - outputScale = vecMax != 0 ? exp2f_rcp(fp8SFVal) : 0.0f; - } else { - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - fp8SFVal = tmp.__x; - SFValue = static_cast(tmp); - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal)) - outputScale = vecMax != 0 - ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - } - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; - -#pragma unroll - for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} - -template -__device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - float const dequant_to_fp16_scale = 6.f * reciprocal_approximate_ftz(SFScaleVal); - - // Dequant fp8 to fp16 - __half2 vec_half2[8]; -#pragma unroll - for (int i = 0; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) { - float2 tmp = static_cast(vec.elts[i]); - tmp.x *= dequant_to_fp16_scale; - tmp.y *= dequant_to_fp16_scale; - vec_half2[i] = __float22half2_rn(tmp); - } - - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec_half2[0]); - // Local maximum value. -#pragma unroll - for (int i = 1; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec_half2[i])); - } - - constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP8_TO_FP4_ELTS_PER_THREAD; - if constexpr (CVT_NUM_THREADS_PER_SF == 2) { - // For block 32, we need to reduce the local max across two threads. - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - } - - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - float SFValueNarrow; - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - __nv_fp8_e8m0 tmp; - tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - SFValueNarrow = static_cast(tmp); - fp8SFVal = tmp.__x; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - fp8SFVal = tmp.__x; - SFValueNarrow = static_cast(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) - float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValueNarrow) : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2]; - -#pragma unroll - for (int i = 0; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) { - fp2Vals[i] = __half22float2(vec_half2[i]); - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint64_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} - -// Quantizes the provided PackedVec into the uint64_t output -template -__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec& vec, uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = cuda_abs(vec.elts[0]); - -// Local maximum value. -#pragma unroll - for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { - localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); - } - - constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; - // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - if constexpr (CVT_NUM_THREADS_PER_SF == 4) { - localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); - } - // Get the final absolute maximum values. - float vecMax = float(cuda_max(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of mxfp8). - float SFValue = vecMax * reciprocal_approximate_ftz(448.0f); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - __nv_fp8_e8m0 tmpSFVal; - tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - SFValue = static_cast(tmpSFVal); - fp8SFVal = tmpSFVal.__x; - // Get the output scale (reciprocal of the SFValue). - float outputScale = vecMax != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; - -#pragma unroll - for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e4m3 values. - uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); - - // Write the e4m3 values to global memory. - return e4m3Vec; -#else - return 0; -#endif -} - -inline __device__ __host__ int64_t get_sf_out_offset_128x4(std::optional batchIdx, int mIdx, - int kIdx, std::optional numRows, - int numColVecs) { - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - // batched tensor - // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4 * innerKStride; // 4 - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * innerMStride; // 16 - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * outerMStride; // 512 - - // SF vector size 16 or 32. We round the "numCols" up to a multiple of 64 or 128. - // It is the same as rounding the "numColVecs" up to a multiple of 4. - int32_t numKTiles = (numColVecs + 4 - 1) / 4; - - int32_t mTileIdx = mIdx / (32 * 4); - int64_t mTileStride = numKTiles * kTileStride; - - // Each SF block has 128 rows so pad rows to the multiple of 128. - int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128; - int64_t bTileStride = numMTiles * mTileStride; - - // Compute the global offset. - int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + - kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return SFOffset; -} - -inline __device__ __host__ int64_t get_sf_out_offset_8x4(std::optional batchIdx, int mIdx, - int kIdx, std::optional numRows, - int numCols) { - // SF layout [numMTiles, numKTiles, 8 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, innerMIdx, innerKIdx] - - // batched tensor - // SF layout [numBTiles, numMTiles, numKTiles, 8 (mTile), 4(kTile)] - // --> index [bTileIdx, mTileIdx, kTileIdx, innerMIdx, innerKIdx] - const int32_t mTile = 8; - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - int32_t innerMIdx = (mIdx % mTile); - int64_t mStride = 4 * innerKStride; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = mTile * mStride; - - int32_t numKTiles = (numCols + 4 - 1) / 4; - int32_t mTileIdx = mIdx / mTile; - int64_t mTileStride = numKTiles * kTileStride; - - int32_t numMTiles = (numRows.value_or(0) + 8 - 1) / 8; - int64_t bTileStride = numMTiles * mTileStride; - - int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + - kTileIdx * kTileStride + innerMIdx * mStride + innerKIdx * innerKStride; - - return SFOffset; -} - -template -__device__ uint8_t* cvt_quant_get_sf_out_offset(std::optional batchIdx, int rowIdx, - int colVecIdx, std::optional numRows, - int numColVecs, SFType* SFout, - QuantizationSFLayout layout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_NUM_THREADS_PER_SF == 1 || CVT_NUM_THREADS_PER_SF == 2 || - CVT_NUM_THREADS_PER_SF == 4); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_NUM_THREADS_PER_SF == 0) { - if (layout == QuantizationSFLayout::SWIZZLED_128x4 || - layout == QuantizationSFLayout::SWIZZLED_8x4) { - // SF vector index (16 elements share one SF in the K dimension). - // numRows and numCols are unpadded. - int32_t kIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - auto SFOffset = layout == QuantizationSFLayout::SWIZZLED_128x4 - ? get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numColVecs) - : get_sf_out_offset_8x4(batchIdx, mIdx, kIdx, numRows, numColVecs); - return reinterpret_cast(SFout) + SFOffset; - } else if (layout == QuantizationSFLayout::LINEAR) { - // Linear row-major layout, no padding required. - int32_t KTileIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; - - int32_t numKTiles = numColVecs; - int64_t mTileStride = numKTiles; - - int64_t BTileStride = numRows.value_or(0) * mTileStride; - - int64_t SFOffset = batchIdx.value_or(0) * BTileStride + rowIdx * mTileStride + KTileIdx; - return reinterpret_cast(SFout) + SFOffset; - } else { - return nullptr; - } - } -#endif - return nullptr; -} +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FP4/MXFP8 Quantization Kernels template __global__ void @@ -754,9 +203,9 @@ quantize_with_block_size( ? CVT_FP8_TO_FP4_ELTS_PER_THREAD : CVT_ELTS_PER_THREAD; - using PackedVec = PackedVec; + using PackedVecT = PackedVec; static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD; // 2 or 4 - static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched."); + static_assert(sizeof(PackedVecT) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched."); // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)). @@ -844,19 +293,20 @@ quantize_with_block_size( } } else { // Load the input vector. - PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVecT in_vec = reinterpret_cast(in)[inOffset]; // Dispatch the quantization kernel. if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + cvt_warp_fp16_to_fp4( + in_vec, SFScaleVal, sf_out); } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { reinterpret_cast(out)[outOffset] = - cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, - sf_out); + cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, ELTS_PER_THREAD, UE8M0_SF>( + in_vec, SFScaleVal, sf_out); } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + cvt_warp_fp16_to_mxfp8(in_vec, sf_out); } } } @@ -867,75 +317,192 @@ quantize_with_block_size( #endif } -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, - SFType* SFout) { +// quantize with TMA in high throughput mode +template +__global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; +__launch_bounds__(288, 2) quantize_with_block_size_tma( +#else +quantize_with_block_size_tma( +#endif + int32_t numbatches, int32_t numRows, int32_t numCols, int32_t numPaddedCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout, + const __grid_constant__ CUtensorMap tensor_map) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using Traits = TmaKernelTraits; + using SmemType = typename Traits::SmemType; + + static constexpr int ELTS_PER_THREAD = Traits::ELTS_PER_THREAD; + static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD; + static constexpr int TMA_ROW_TILE = Traits::TMA_ROW_TILE; + static constexpr int TMA_COL_TILE = Traits::TMA_COL_TILE; + static constexpr int NUM_STAGES = Traits::NUM_STAGES; + static constexpr int THREADS_PER_ROW = Traits::THREADS_PER_ROW; + static constexpr int ROWS_PER_WARP = Traits::ROWS_PER_WARP; + static constexpr int ROW_ITERATIONS = Traits::ROW_ITERATIONS; + static constexpr int SMEM_STAGE_SIZE = Traits::SMEM_STAGE_SIZE; + static constexpr int NUM_CONSUMER_WARPS = Traits::NUM_CONSUMER_WARPS; + + using PackedVecT = PackedVec; + static_assert(sizeof(PackedVecT) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched."); + static_assert(SF_VEC_SIZE == 16, "Only support SF_VEC_SIZE = 16 for TMA quantization."); + + int warpIdx = threadIdx.x / 32; + int numWarp = blockDim.x / 32; + int laneIdx = threadIdx.x % 32; + + // IMPORTANT: TMA with SWIZZLE_128B requires 128-byte aligned shared memory. + extern __shared__ __align__(1024) uint8_t smem_raw[]; + + // SMEM data starts at the beginning of dynamic shared memory (128-byte aligned) + SmemType* smem = reinterpret_cast(smem_raw); + + // Place barriers at the end of dynamic shared memory + Barrier* barrier_start_ptr = reinterpret_cast(smem_raw + Traits::SMEM_DATA_SIZE); + + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = + PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (NUM_STAGES + i); }); + + // Get the global scaling factor + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + // Is it swizzled layout? + bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4; - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; + // The number of padded rows considering 128x4 or 8x4 SF layout. + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, rowTile) : numRows; + int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols; - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; + asm volatile("griddepcontrol.wait;"); - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; + // TMA barrier initialization. + if (warpIdx == 0 and laneIdx == 0) { +#pragma unroll + for (int i = 0; i < NUM_STAGES; i++) { + full_barriers[i]->init(1); + empty_barriers[i]->init(NUM_CONSUMER_WARPS); +#pragma unroll + for (int j = 0; j < NUM_CONSUMER_WARPS; j++) { + empty_barriers[i]->arrive(); + } + } + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; + uint32_t stage_idx = 0, phase = 0; - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; + if (warpIdx == 0 and elect_one_sync()) { + // Producer warp - TMA loads + for (int rowIdx = blockIdx.x * TMA_ROW_TILE; rowIdx < numPaddedRowsForSf; + rowIdx += gridDim.x * TMA_ROW_TILE) { + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + for (int colIdx = 0; colIdx < numCols; colIdx += NUM_CONSUMER_WARPS * TMA_COL_TILE) { + empty_barriers[stage_idx]->wait(phase); + + // Use batchIdx * numRows + rowIdx to access the correct batch in the flattened + // [B*M, N] tensor. The tensor map is created with total rows = B * M. + cute::SM90_TMA_LOAD_3D::copy(&tensor_map, + reinterpret_cast(full_barriers[stage_idx]), 0ULL, + smem + stage_idx * SMEM_STAGE_SIZE, 0, + batchIdx * numRows + rowIdx, colIdx / TMA_COL_TILE); + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_STAGE_SIZE * sizeof(SmemType)); + + stage_idx = stage_idx == NUM_STAGES - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + } + } + } + } else if (warpIdx >= 1 and warpIdx <= 8) { + // Consumer warps + int consumerWarpIdx = warpIdx - 1; + typename Traits::ThreadIndexing tidx(laneIdx, consumerWarpIdx); - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + - innerMIdx * innerMStride + innerKIdx * innerKStride; + for (int rowIdx = blockIdx.x * TMA_ROW_TILE; rowIdx < numPaddedRowsForSf; + rowIdx += gridDim.x * TMA_ROW_TILE) { + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + std::optional optionalBatchIdx = batchIdx; + std::optional optionalNumRows = numRows; + tidx.reset(); // Reset column indices for each row iteration - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} + int threadRowIdxGlobal; + int64_t rowOffset, threadOutOffset; -__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); } + for (int colIdx = 0; colIdx < numCols; colIdx += NUM_CONSUMER_WARPS * TMA_COL_TILE) { + threadRowIdxGlobal = rowIdx + tidx.rowIdxLocal; + rowOffset = static_cast(batchIdx * numRows + threadRowIdxGlobal) * numPaddedCols; + threadOutOffset = (rowOffset + tidx.colIdx) >> 4; -template -inline __device__ void silu_and_mul(PackedVec& x_vec, const PackedVec& y_vec) { - float2 x[CVT_FP4_ELTS_PER_THREAD / 2]; - float2 y[CVT_FP4_ELTS_PER_THREAD / 2]; + full_barriers[stage_idx]->wait(phase); #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - x[i] = __half22float2(x_vec.elts[i]); - y[i] = __half22float2(y_vec.elts[i]); - x[i].x = silu(x[i].x) * y[i].x; - x[i].y = silu(x[i].y) * y[i].y; - x_vec.elts[i] = __float22half2_rn(x[i]); - } else { - x[i] = __bfloat1622float2(x_vec.elts[i]); - y[i] = __bfloat1622float2(y_vec.elts[i]); - x[i].x = silu(x[i].x) * y[i].x; - x[i].y = silu(x[i].y) * y[i].y; - x_vec.elts[i] = __float22bfloat162_rn(x[i]); + for (int i = 0; i < ROW_ITERATIONS; i++) { + auto sf_out = cvt_quant_get_sf_out_offset( + optionalBatchIdx, threadRowIdxGlobal, tidx.colVecIdx, optionalNumRows, + numPaddedCols / SF_VEC_SIZE, SFout, layout); + + // Set padded columns to 0 + if (threadRowIdxGlobal < numRows && tidx.colIdx >= numCols && + tidx.colIdx < numPaddedCols) { + reinterpret_cast(out)[threadOutOffset] = 0ull; + } + + // Set SF padding to 0 + if (threadRowIdxGlobal >= numRows || tidx.colIdx >= numCols) { + if (sf_out != nullptr) { + sf_out[0] = 0x00; + } + } else { + SmemType* smem_stage = smem + stage_idx * SMEM_STAGE_SIZE; + float4 const* base_float4 = reinterpret_cast( + smem_stage + consumerWarpIdx * TMA_COL_TILE * TMA_ROW_TILE + + i * TMA_COL_TILE * ROWS_PER_WARP); + + // Load input vector from shared memory + PackedVecT in_vec = Traits::template load_input_vec( + base_float4, tidx.rowIdxLocal, tidx.colIdxLocal); + + // Dispatch the quantization kernel + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[threadOutOffset] = + cvt_warp_fp16_to_fp4( + in_vec, SFScaleVal, sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { + reinterpret_cast(out)[threadOutOffset] = + cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, ELTS_PER_THREAD, UE8M0_SF>( + in_vec, SFScaleVal, sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[threadOutOffset] = + cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + } + } + + // Update row index and output offset + threadRowIdxGlobal += ROWS_PER_WARP; + rowOffset = + static_cast(batchIdx * numRows + threadRowIdxGlobal) * numPaddedCols; + threadOutOffset = (rowOffset + tidx.colIdx) >> 4; + } + + // Update column offset + tidx.advance_col(); + threadOutOffset = (rowOffset + tidx.colIdx) >> 4; + + if (laneIdx == 0) { + empty_barriers[stage_idx]->arrive(); + } + + stage_idx = stage_idx == NUM_STAGES - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + } + } } } + asm volatile("griddepcontrol.launch_dependents;"); +#endif } // Use UE4M3 by default. @@ -949,9 +516,9 @@ cvt_fp16_to_fp4_expert( int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out, uint32_t* SFout, int32_t* mask, bool use_silu_and_mul, int n_experts) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - using PackedVec = PackedVec; + using PackedVecT = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + static_assert(sizeof(PackedVecT) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); // Input tensor row/col loops. @@ -1003,10 +570,10 @@ cvt_fp16_to_fp4_expert( } int64_t inOffset = rowIdx * actualColsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVecT in_vec = reinterpret_cast(in)[inOffset]; if (use_silu_and_mul) { - PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; - silu_and_mul(in_vec, in_vec_mul); + PackedVecT in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); } // Get the output tensor offset. @@ -1025,10 +592,12 @@ cvt_fp16_to_fp4_expert( int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout; - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( rowIdx_in_expert, colIdx, numCols, SFout_in_expert); - out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + out_pos = cvt_warp_fp16_to_fp4( + in_vec, SFScaleVal, sf_out); } #endif } diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh new file mode 100644 index 0000000000..669bcccd93 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh @@ -0,0 +1,888 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/kernels/quantization.h" + +using namespace tensorrt_llm::common; + +namespace tensorrt_llm { +namespace kernels { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// DstVec type traits for quantization + +template +struct DstVec { + static_assert("not implemented."); +}; + +template <> +struct DstVec { + using Type = uint32_t; +}; + +template <> +struct DstVec { + using Type = uint2; +}; + +#ifdef ENABLE_BF16 + +template <> +struct DstVec<__nv_bfloat162, 4> { + using Type = uint2; +}; + +#endif // ENABLE_BF16 + +template +struct DstVec { + static_assert(sizeof(T) == 4, "not implemented."); + using Type = uint32_t; +}; + +template +struct DstVec { + static_assert(sizeof(T) == 2, "not implemented."); + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helper function of getting the absMax of all elements in the vector after clamping. +// Pack two elements in order to use possible hmax2 instructions. +template +inline __device__ void clampAndAbsMax(T& localMax, uint4& vec, T const clampMin, T const clampMax) { + static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T); + +#pragma unroll + for (int i = 0; i < NUM_ELTS; ++i) { + T& val = reinterpret_cast(&vec)[i]; + val = cuda_clamp(val, clampMin, clampMax); + localMax = cuda_max(localMax, cuda_abs(val)); + } +} + +// Helper function of quantizing the vector and storing it to global memory. +// Pack two elements in order to use fast convert instructions. +template +inline __device__ void quantizeAndStore(QuantT* dstPtr, uint4 vec, T const clampMin, + T const clampMax, float const scaleOrigQuant) { + static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T); + + using DstVecType = typename DstVec::Type; + DstVecType dstVec; +#pragma unroll + for (int i = 0; i < NUM_ELTS; ++i) { + T val = reinterpret_cast(&vec)[i]; + // Values loaded from smem has already been clamped. + if constexpr (!USE_SMEM) { + val = cuda_clamp(val, clampMin, clampMax); + } + float2 val2 = cuda_cast(val); + val2.x *= scaleOrigQuant; + val2.y *= scaleOrigQuant; + QuantT quantVal = cuda_cast(val2); + reinterpret_cast(&dstVec)[i] = quantVal; + } + // Store to destination buffer. + *reinterpret_cast(dstPtr) = dstVec; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FP4/MXFP8 Conversion Functions + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]), + "f"(array[6]), "f"(array[7])); + return val; +#else + // static_assert(false, "not supported."); + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), + "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + // static_assert(false, "not supported."); + return 0; +#endif +} + +// Convert 8 float2 values into 16 e2m1 values (represented as one uint64_t). +inline __device__ uint64_t fp32_vec_to_e2m1(float2 (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint64_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + ".reg .b8 byte4;\n" + ".reg .b8 byte5;\n" + ".reg .b8 byte6;\n" + ".reg .b8 byte7;\n" + ".reg .b32 val0;\n" + ".reg .b32 val1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte4, %10, %9;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte5, %12, %11;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte6, %14, %13;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte7, %16, %15;\n" + "mov.b32 val0, {byte0, byte1, byte2, byte3};\n" + "mov.b32 val1, {byte4, byte5, byte6, byte7};\n" + "mov.b64 %0, {val0, val1};\n" + "}" + : "=l"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), + "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), "f"(array[4].x), "f"(array[4].y), + "f"(array[5].x), "f"(array[5].y), "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), + "f"(array[7].y)); + return val; +#else + // static_assert(false, "not supported."); + return 0; +#endif +} + +inline __device__ uint32_t elect_one_sync() { + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} + +// Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t). +inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) { + union { + uint64_t val; + __nv_fp8x2_e4m3 elts[4]; + } u; + + static_assert(sizeof(u.val) == sizeof(u.elts), + "Expected to alias uint64_t and __nv_fp8x2_e4m3[4]"); + + u.elts[0] = __nv_fp8x2_e4m3(array[0]); + u.elts[1] = __nv_fp8x2_e4m3(array[1]); + u.elts[2] = __nv_fp8x2_e4m3(array[2]); + u.elts[3] = __nv_fp8x2_e4m3(array[3]); + return u.val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +__device__ __forceinline__ float exp2f_rcp(uint8_t exp) { + constexpr uint32_t FP32_EXPONENT_BIAS = 127; + return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(exp)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Type converters for packed vectors + +template +struct TypeConverter { + using PackedType = void; +}; + +template <> +struct TypeConverter { + using PackedType = half2; +}; + +#ifdef ENABLE_BF16 +template <> +struct TypeConverter<__nv_bfloat16> { + using PackedType = __nv_bfloat162; +}; +#endif + +// Define a packed data type parameterized by the number of elements. +// For half/bf16: uses half2/bfloat162, so NUM_ELTS elements require NUM_ELTS/2 pairs. +// For FP8: uses __nv_fp8x2_e4m3, so NUM_ELTS elements require NUM_ELTS/2 pairs. +template +struct PackedVec { + typename TypeConverter::PackedType elts[NUM_ELTS / 2]; + static_assert(sizeof(elts) == sizeof(Type) * NUM_ELTS, + "Vector size should match the number of elements per thread."); +}; + +// Specialization for FP8 with default 16 elements +template +struct PackedVec<__nv_fp8_e4m3, NUM_ELTS> { + __nv_fp8x2_e4m3 elts[NUM_ELTS / 2]; + static_assert(sizeof(elts) == sizeof(__nv_fp8_e4m3) * NUM_ELTS, + "Vector size should match the number of elements per thread."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Quantization helper functions + +// Quantizes the provided PackedVec into the uint32_t or uint64_t output +template +__device__ std::conditional_t cvt_warp_fp16_to_fp4( + PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_ELTS_PER_THREAD == 8 || CVT_ELTS_PER_THREAD == 16, + "CVT_ELTS_PER_THREAD must be 8 or 16"); + + using ReturnType = std::conditional_t; + + // Get absolute maximum values among the local 8 values. + auto localMax = cuda_abs(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { + localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); + } + + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; + // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). + if constexpr (CVT_NUM_THREADS_PER_SF >= 2) { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + } + if constexpr (CVT_NUM_THREADS_PER_SF == 4) { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + } + // Get the final absolute maximum values. + float vecMax = float(cuda_max(localMax.x, localMax.y)); + + // 8 bits representation of the SF. + uint8_t fp8SFVal; + float outputScale; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + __nv_fp8_e8m0 tmp; + // Scale the max value to the range of E2m1. + vecMax *= reciprocal_approximate_ftz(6.0f); + tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf); + + fp8SFVal = tmp.__x; + outputScale = vecMax != 0 ? exp2f_rcp(fp8SFVal) : 0.0f; + } else { + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + fp8SFVal = tmp.__x; + SFValue = static_cast(tmp); + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal)) + outputScale = vecMax != 0 + ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + } + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + ReturnType e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +template +__device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + // Because the return value is a uint64_t, we need to ensure that the CVT_ELTS_PER_THREAD is 16. + static_assert(CVT_ELTS_PER_THREAD == 16, "CVT_ELTS_PER_THREAD must be 16"); + + float const dequant_to_fp16_scale = 6.f * reciprocal_approximate_ftz(SFScaleVal); + + // Dequant fp8 to fp16 + __half2 vec_half2[8]; +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + float2 tmp = static_cast(vec.elts[i]); + tmp.x *= dequant_to_fp16_scale; + tmp.y *= dequant_to_fp16_scale; + vec_half2[i] = __float22half2_rn(tmp); + } + + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec_half2[0]); + // Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec_half2[i])); + } + + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; + if constexpr (CVT_NUM_THREADS_PER_SF == 2) { + // For block 32, we need to reduce the local max across two threads. + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + } + + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + float SFValueNarrow; + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + __nv_fp8_e8m0 tmp; + tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValueNarrow = static_cast(tmp); + fp8SFVal = tmp.__x; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + fp8SFVal = tmp.__x; + SFValueNarrow = static_cast(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) + float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValueNarrow) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i] = __half22float2(vec_half2[i]); + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint64_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Quantizes the provided PackedVec into the uint64_t output +template +__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec& vec, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = cuda_abs(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { + localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); + } + + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; + // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). + if constexpr (CVT_NUM_THREADS_PER_SF >= 2) { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + } + if constexpr (CVT_NUM_THREADS_PER_SF == 4) { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + } + // Get the final absolute maximum values. + float vecMax = float(cuda_max(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of mxfp8). + float SFValue = vecMax * reciprocal_approximate_ftz(448.0f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmpSFVal; + tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + SFValue = static_cast(tmpSFVal); + fp8SFVal = tmpSFVal.__x; + // Get the output scale (reciprocal of the SFValue). + float outputScale = vecMax != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e4m3 values. + uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); + + // Write the e4m3 values to global memory. + return e4m3Vec; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Scale factor offset calculation functions + +inline __device__ __host__ int64_t get_sf_out_offset_128x4(std::optional batchIdx, int mIdx, + int kIdx, std::optional numRows, + int numColVecs) { + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + // batched tensor + // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4 * innerKStride; // 4 + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * innerMStride; // 16 + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * outerMStride; // 512 + + // SF vector size 16 or 32. We round the "numCols" up to a multiple of 64 or 128. + // It is the same as rounding the "numColVecs" up to a multiple of 4. + int32_t numKTiles = (numColVecs + 4 - 1) / 4; + + int32_t mTileIdx = mIdx / (32 * 4); + int64_t mTileStride = numKTiles * kTileStride; + + // Each SF block has 128 rows so pad rows to the multiple of 128. + int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128; + int64_t bTileStride = numMTiles * mTileStride; + + // Compute the global offset. + int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + + kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return SFOffset; +} + +inline __device__ __host__ int64_t get_sf_out_offset_8x4(std::optional batchIdx, int mIdx, + int kIdx, std::optional numRows, + int numCols) { + // SF layout [numMTiles, numKTiles, 8 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, innerMIdx, innerKIdx] + + // batched tensor + // SF layout [numBTiles, numMTiles, numKTiles, 8 (mTile), 4(kTile)] + // --> index [bTileIdx, mTileIdx, kTileIdx, innerMIdx, innerKIdx] + const int32_t mTile = 8; + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + int32_t innerMIdx = (mIdx % mTile); + int64_t mStride = 4 * innerKStride; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = mTile * mStride; + + int32_t numKTiles = (numCols + 4 - 1) / 4; + int32_t mTileIdx = mIdx / mTile; + int64_t mTileStride = numKTiles * kTileStride; + + int32_t numMTiles = (numRows.value_or(0) + 8 - 1) / 8; + int64_t bTileStride = numMTiles * mTileStride; + + int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + + kTileIdx * kTileStride + innerMIdx * mStride + innerKIdx * innerKStride; + + return SFOffset; +} + +template +__device__ uint8_t* cvt_quant_get_sf_out_offset(std::optional batchIdx, int rowIdx, + int colVecIdx, std::optional numRows, + int numColVecs, SFType* SFout, + QuantizationSFLayout layout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_NUM_THREADS_PER_SF == 1 || CVT_NUM_THREADS_PER_SF == 2 || + CVT_NUM_THREADS_PER_SF == 4); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_NUM_THREADS_PER_SF == 0) { + if (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4) { + // SF vector index (16 elements share one SF in the K dimension). + // numRows and numCols are unpadded. + int32_t kIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + auto SFOffset = layout == QuantizationSFLayout::SWIZZLED_128x4 + ? get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numColVecs) + : get_sf_out_offset_8x4(batchIdx, mIdx, kIdx, numRows, numColVecs); + return reinterpret_cast(SFout) + SFOffset; + } else if (layout == QuantizationSFLayout::LINEAR) { + // Linear row-major layout, no padding required. + int32_t KTileIdx = colVecIdx / CVT_NUM_THREADS_PER_SF; + + int32_t numKTiles = numColVecs; + int64_t mTileStride = numKTiles; + + int64_t BTileStride = numRows.value_or(0) * mTileStride; + + int64_t SFOffset = batchIdx.value_or(0) * BTileStride + rowIdx * mTileStride + KTileIdx; + return reinterpret_cast(SFout) + SFOffset; + } else { + return nullptr; + } + } +#endif + return nullptr; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + + innerMIdx * innerMStride + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); } + +template +inline __device__ void silu_and_mul(PackedVec& x_vec, + const PackedVec& y_vec) { + float2 x[CVT_ELTS_PER_THREAD / 2]; + float2 y[CVT_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + x[i] = __half22float2(x_vec.elts[i]); + y[i] = __half22float2(y_vec.elts[i]); + x[i].x = silu(x[i].x) * y[i].x; + x[i].y = silu(x[i].y) * y[i].y; + x_vec.elts[i] = __float22half2_rn(x[i]); + } else { + x[i] = __bfloat1622float2(x_vec.elts[i]); + y[i] = __bfloat1622float2(y_vec.elts[i]); + x[i].x = silu(x[i].x) * y[i].x; + x[i].y = silu(x[i].y) * y[i].y; + x_vec.elts[i] = __float22bfloat162_rn(x[i]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper functions for quantization kernel with TMA in high throughput mode + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ explicit PatternVisitor(FuncT&& func) : func(std::forward(func)) {} + + __device__ __host__ auto operator[](const uint32_t& i) { return func(i); } +}; + +template +struct TmaKernelTraits; + +// Base template for 2-byte types (half, __nv_bfloat16) +// 2 bytes per element, 16 elements per thread = 32 bytes = 2 float4s +template +struct TmaKernelTraitsTwoBytes { + using InputType = T; + using SmemType = T; + + static constexpr int TMA_ROW_TILE = 16; + static constexpr int TMA_COL_TILE = 64; // 64 elements = 128 bytes + static constexpr int NUM_STAGES = 4; + static constexpr int SMEM_ROWS = TMA_ROW_TILE; // Must match TMA_ROW_TILE for TMA loads + static constexpr int SMEM_COLS = 8 * TMA_COL_TILE; // 8 warps * 64 cols + static constexpr int THREADS_PER_ROW = 4; // laneIdx % 4 + static constexpr int ROWS_PER_WARP = 8; // 32 / 4 + static constexpr int ROW_ITERATIONS = TMA_ROW_TILE / ROWS_PER_WARP; // 2 + static constexpr int ELTS_PER_THREAD = 16; + static constexpr int NUM_CONSUMER_WARPS = 8; + + static constexpr size_t SMEM_DATA_SIZE = NUM_STAGES * SMEM_ROWS * SMEM_COLS * sizeof(SmemType); + static constexpr int SMEM_STAGE_SIZE = SMEM_ROWS * SMEM_COLS; + + // Thread indexing helper - encapsulates all index calculations + struct ThreadIndexing { + int const colIdxLocal; // Thread's local column index within warp tile (constant) + int const rowIdxLocal; // Thread's local row index within warp (constant) + int const baseColIdx; // Base column index for this thread (constant) + int const baseColVecIdx; // Base column vector index (constant) + int colIdx; // Thread's global column index (in elements) + int colVecIdx; // Thread's column index in SF vector units + + __device__ ThreadIndexing(int laneIdx, int consumerWarpIdx) + : colIdxLocal(laneIdx % THREADS_PER_ROW), + rowIdxLocal(laneIdx / THREADS_PER_ROW), + baseColIdx(consumerWarpIdx * TMA_COL_TILE + colIdxLocal * ELTS_PER_THREAD), + baseColVecIdx(consumerWarpIdx * (TMA_COL_TILE / ELTS_PER_THREAD) + colIdxLocal), + colIdx(baseColIdx), + colVecIdx(baseColVecIdx) {} + + __device__ void reset() { + colIdx = baseColIdx; + colVecIdx = baseColVecIdx; + } + + __device__ void advance_col() { + colIdx += NUM_CONSUMER_WARPS * TMA_COL_TILE; + colVecIdx = colIdx / ELTS_PER_THREAD; + } + }; + + // Load input vector from shared memory for 2-byte types + // Uses SWIZZLE_128B indexing, loads 2 float4s (32 bytes = 16 elements) + template + __device__ static PackedVecT load_input_vec(float4 const* base_float4, int threadRowIdxLocal, + int threadColIdxLocal) { + // Compute swizzled indices for SWIZZLE_128B + int swizzled_col = threadColIdxLocal * 2; // Each thread reads 2 float4s + int col_after_swizzle_0 = threadRowIdxLocal ^ swizzled_col; + int col_after_swizzle_1 = threadRowIdxLocal ^ (swizzled_col + 1); + int float4_idx_0 = threadRowIdxLocal * TMA_COL_TILE / 8 + col_after_swizzle_0; + int float4_idx_1 = threadRowIdxLocal * TMA_COL_TILE / 8 + col_after_swizzle_1; + + // Load 2 float4s (32 bytes) + float4 load_data[2]; + load_data[0] = base_float4[float4_idx_0]; + load_data[1] = base_float4[float4_idx_1]; + return reinterpret_cast(load_data[0]); + } +}; + +// Specialization for half +template <> +struct TmaKernelTraits : TmaKernelTraitsTwoBytes {}; + +// Specialization for BF16 +#ifdef ENABLE_BF16 +template <> +struct TmaKernelTraits<__nv_bfloat16> : TmaKernelTraitsTwoBytes<__nv_bfloat16> {}; +#endif + +// Specialization for FP8 input (FP8_TO_FP4 native) +// FP8: 1 byte per element, 16 elements per thread = 16 bytes = 1 float4 +template <> +struct TmaKernelTraits<__nv_fp8_e4m3> { + using InputType = __nv_fp8_e4m3; + using SmemType = __nv_fp8_e4m3; + + static constexpr int TMA_ROW_TILE = 8; + static constexpr int TMA_COL_TILE = 128; // 128 FP8 elements = 128 bytes + static constexpr int NUM_STAGES = 6; + static constexpr int SMEM_ROWS = TMA_ROW_TILE; // Must match TMA_ROW_TILE for TMA loads + static constexpr int SMEM_COLS = 8 * TMA_COL_TILE; // 8 warps * 128 cols + static constexpr int THREADS_PER_ROW = 8; // laneIdx % 8 + static constexpr int ROWS_PER_WARP = 4; // 32 / 8 + static constexpr int ROW_ITERATIONS = TMA_ROW_TILE / ROWS_PER_WARP; // 2 + static constexpr int ELTS_PER_THREAD = 16; + static constexpr int NUM_CONSUMER_WARPS = 8; + + static constexpr size_t SMEM_DATA_SIZE = NUM_STAGES * SMEM_ROWS * SMEM_COLS * sizeof(SmemType); + static constexpr int SMEM_STAGE_SIZE = SMEM_ROWS * SMEM_COLS; + + // Thread indexing helper - encapsulates all index calculations + struct ThreadIndexing { + int const colIdxLocal; // Thread's local column index within warp tile (constant) + int const rowIdxLocal; // Thread's local row index within warp (constant) + int const baseColIdx; // Base column index for this thread (constant) + int const baseColVecIdx; // Base column vector index (constant) + int colIdx; // Thread's global column index (in elements) + int colVecIdx; // Thread's column index in SF vector units + + __device__ ThreadIndexing(int laneIdx, int consumerWarpIdx) + : colIdxLocal(laneIdx % THREADS_PER_ROW), + rowIdxLocal(laneIdx / THREADS_PER_ROW), + baseColIdx(consumerWarpIdx * TMA_COL_TILE + colIdxLocal * ELTS_PER_THREAD), + baseColVecIdx(consumerWarpIdx * (TMA_COL_TILE / ELTS_PER_THREAD) + colIdxLocal), + colIdx(baseColIdx), + colVecIdx(baseColVecIdx) {} + + __device__ void reset() { + colIdx = baseColIdx; + colVecIdx = baseColVecIdx; + } + + __device__ void advance_col() { + colIdx += NUM_CONSUMER_WARPS * TMA_COL_TILE; + colVecIdx = colIdx / ELTS_PER_THREAD; + } + }; + + // Load input vector from shared memory for FP8 + // Uses linear indexing (no swizzle), loads 1 float4 (16 bytes = 16 FP8 elements) + template + __device__ static PackedVecT load_input_vec(float4 const* base_float4, int threadRowIdxLocal, + int threadColIdxLocal) { + // Linear indexing: compute float4 offset directly + int float4_idx = threadRowIdxLocal * (TMA_COL_TILE / 16) + threadColIdxLocal; + + // Load 1 float4 (16 bytes) + float4 load_data = base_float4[float4_idx]; + return reinterpret_cast(load_data); + } +}; + +// Shared memory size constants (for kernel launch) +constexpr size_t TMA_BARRIER_SECTION_SIZE = 1024; // Reserved for barriers (aligned) + +template +constexpr size_t get_tma_smem_size() { + return TMA_BARRIER_SECTION_SIZE + TmaKernelTraits::SMEM_DATA_SIZE; +} + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/tests/utils/test_fp4_quantize.py b/tests/utils/test_fp4_quantize.py index 5060210d71..7242d25591 100644 --- a/tests/utils/test_fp4_quantize.py +++ b/tests/utils/test_fp4_quantize.py @@ -19,8 +19,14 @@ DTYPES = [torch.float16, torch.bfloat16] # The batch dimension doesn't need to be multiple of 128 -SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)] -BATCH_SHAPES = [(1, 256, 128), (2, 128, 64), (3, 256, 128), (1, 120, 64)] +SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256), (2048, 2048)] +BATCH_SHAPES = [ + (1, 256, 128), + (2, 128, 64), + (3, 256, 128), + (1, 120, 64), + (128, 2048, 2048), +] SEEDS = [42] CUDA_DEVICES = ["cuda:0"]