@@ -991,12 +991,12 @@ __device__ auto quantizePackedFPXValue(
   if constexpr (is_fp8) {
     return [](PackedVec<GemmOutputType>& vec, float /* ignored */, uint8_t* SFout) -> uint64_t {
       static_assert(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == VecSize);
-      return cvt_warp_fp16_to_mxfp8<GemmOutputType, VecSize>(vec, SFout);
+      return cvt_warp_fp16_to_mxfp8<GemmOutputType, VecSize, CVT_ELTS_PER_THREAD>(vec, SFout);
     };
   } else {
     return (scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4)
-               ? &cvt_warp_fp16_to_fp4<GemmOutputType, VecSize, false>
-               : &cvt_warp_fp16_to_fp4<GemmOutputType, VecSize, true>;
+               ? &cvt_warp_fp16_to_fp4<GemmOutputType, VecSize, CVT_ELTS_PER_THREAD, false>
+               : &cvt_warp_fp16_to_fp4<GemmOutputType, VecSize, CVT_ELTS_PER_THREAD, true>;
   }
 }();

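Note on the hunk above: the NVFP4 branch selects between two instantiations with a ternary, so both must collapse to a single function-pointer type, which is why the new CVT_ELTS_PER_THREAD parameter has to appear in both arms. A minimal stand-in (hypothetical names, not the real converters) showing the constraint:

#include <cstdio>

// Hypothetical stand-in for the converter family: the ternary compiles only
// because both instantiations share every signature-affecting template
// parameter, mirroring why CVT_ELTS_PER_THREAD is added to both arms above.
template <int VecSize, int EltsPerThread, bool UE8M0>
int convert(int x) {
  return UE8M0 ? x * VecSize : x * EltsPerThread;
}

int main() {
  bool nvfp4 = true;
  auto* fn = nvfp4 ? &convert<8, 4, false> : &convert<8, 4, true>;
  std::printf("%d\n", fn(3));  // 12: EltsPerThread path, 3 * 4
}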
155 changes: 154 additions & 1 deletion csrc/nv_internal/cpp/kernels/quantization.cu
@@ -14,6 +14,8 @@
* limitations under the License.
*/

#include <cuda.h>
#include <cudaTypedefs.h>
#include <float.h>

#include "tensorrt_llm/common/assert.h"
@@ -93,6 +95,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
int32_t* SFOuput, QuantizationSFLayout layout, int multiProcessorCount,
bool enable_pdl, cudaStream_t stream) {
// Fixed SF_VEC_SIZE as 32
// TODO: TMA quantization for MXFP8 is not supported yet because of SF_VEC_SIZE = 32.
static constexpr int SF_VEC_SIZE = 32;

// Grid, Block size.
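For context on the TODO above: MXFP8 fixes one scale factor per 32 elements, while the TMA path introduced below is only wired up for 16-element (NVFP4-style) scale blocks. Rough scale-factor bookkeeping, with an assumed example row width:

// Illustrative only: one scale factor per kMxfp8SfVecSize consecutive
// elements along a row of width n (n = 4096 is an assumed example value).
constexpr int kMxfp8SfVecSize = 32;       // fixed for MXFP8, per the comment above
int n_example = 4096;
int sfs_per_row = n_example / kMxfp8SfVecSize;  // 128 scale factors per row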
@@ -178,16 +181,155 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// TMA tensor map creation helpers

template <typename T>
CUtensorMap make_3d_tma_copy_desc(T* global_address, uint64_t gmem_dim[3],
uint64_t stride_in_bytes[2], uint32_t smem_dim[3],
CUtensorMapSwizzle swizzle_type) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 3;
uint32_t elem_strides[rank] = {1, 1, 1};

// Get pointer to cuTensorMapEncodeTiled
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
cudaEnableDefault, &driver_status);
#else
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault,
&driver_status);
#endif

if (driver_status != cudaDriverEntryPointSuccess) {
TLLM_CHECK_WITH_INFO(false, "Failed to get cuTensorMapEncodeTiled entry point");
}

auto encode_func =
reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);

CUtensorMapDataType data_type;
if constexpr (std::is_same_v<T, half>) {
data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
}

CUresult result =
encode_func(&tensor_map, data_type, rank, global_address, gmem_dim, stride_in_bytes, smem_dim,
elem_strides, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
TLLM_CHECK_WITH_INFO(result == CUDA_SUCCESS, "Failed to encode TMA tensor map");
return tensor_map;
}

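A hypothetical usage sketch of make_3d_tma_copy_desc, mirroring the geometry that launchFP4QuantizationTma sets up below. Sizes are illustrative, and the 32-row box is a placeholder since TMA_ROW_TILE comes from TmaKernelTraits<T>. With CU_TENSOR_MAP_SWIZZLE_128B the innermost box extent is capped at 128 bytes, which a 64-element half tile exactly fills:

// Hypothetical sizes; input_ptr would be a device pointer to a [rows, n] half tensor.
half* input_ptr = nullptr;                  // placeholder device pointer
uint64_t rows = 2048, n = 4096, tile = 64;  // tile = TMA_COL_TILE below
uint64_t gmem_dim[3] = {tile, rows, n / tile};
uint64_t stride_bytes[2] = {n * sizeof(half), tile * sizeof(half)};
uint32_t box[3] = {64, 32, 8};  // {col tile, row tile (placeholder), consumer warps}
static_assert(64 * sizeof(half) <= 128, "128B swizzle caps the inner extent at 128 bytes");
alignas(64) CUtensorMap desc = make_3d_tma_copy_desc(
    input_ptr, gmem_dim, stride_bytes, box, CU_TENSOR_MAP_SWIZZLE_128B);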
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP4/MXFP8 Quantization

// Helper function to launch TMA quantization kernel
template <BlockScaleQuantizationType quantization_type, typename T, int SF_VEC_SIZE>
void launchFP4QuantizationTma(int b, int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
cudaStream_t stream) {
using Traits = TmaKernelTraits<T>;
constexpr int TMA_ROW_TILE = Traits::TMA_ROW_TILE;
constexpr int TMA_COL_TILE = Traits::TMA_COL_TILE;
constexpr int NUM_CONSUMER_WARPS = 8;

// Compute effective rows for swizzled layouts
int effectiveRows = computeEffectiveRows(m, layout);

// Grid and block configuration for TMA kernel
// TMA kernel uses 288 threads: 1 producer warp + 8 consumer warps
dim3 block(288);
// Each block handles TMA_ROW_TILE rows
int numRowTiles = (effectiveRows + TMA_ROW_TILE - 1) / TMA_ROW_TILE;
dim3 grid(std::min(numRowTiles, multiProcessorCount * 2));

// Dynamic shared memory size
size_t smem_size = get_tma_smem_size<T>();

// Create 3D TMA tensor map descriptor.
// The TMA kernel loads a box of [TMA_COL_TILE, TMA_ROW_TILE, NUM_CONSUMER_WARPS] elements per
// TMA call. The global tensor is treated as [TMA_COL_TILE, B*M, num_tiles] where
// num_tiles = N / TMA_COL_TILE. We use b * m (not b * effectiveRows) because batches are
// stored contiguously without padding between them.
int num_col_tiles = (n + TMA_COL_TILE - 1) / TMA_COL_TILE;
uint64_t gmem_dim[3] = {
static_cast<uint64_t>(TMA_COL_TILE), // Elements per tile (contiguous in memory)
static_cast<uint64_t>(b * m), // Total rows across all batches
static_cast<uint64_t>(num_col_tiles) // Number of column tiles
};
uint64_t stride_in_bytes[2] = {
static_cast<uint64_t>(n * sizeof(T)), // Stride between rows (in bytes)
static_cast<uint64_t>(TMA_COL_TILE * sizeof(T)) // Stride between tiles (in bytes)
};
uint32_t smem_dim[3] = {
static_cast<uint32_t>(TMA_COL_TILE), // Elements loaded per tile
static_cast<uint32_t>(TMA_ROW_TILE), // Rows loaded per TMA call
static_cast<uint32_t>(NUM_CONSUMER_WARPS) // Number of tiles loaded (for 8 consumer warps)
};

// Use SWIZZLE_128B for half/bf16 (2-byte types), SWIZZLE_NONE for FP8 (1-byte types)
constexpr CUtensorMapSwizzle swizzle_type =
(std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>)
? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B
: CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
// CUtensorMap must be 64-byte aligned.
alignas(64) CUtensorMap tensor_map = make_3d_tma_copy_desc(
const_cast<T*>(input), gmem_dim, stride_in_bytes, smem_dim, swizzle_type);

// Select and launch the TMA kernel
auto* kernel_instance =
useUE8M0 ? &quantize_with_block_size_tma<quantization_type, T, SF_VEC_SIZE, true>
: &quantize_with_block_size_tma<quantization_type, T, SF_VEC_SIZE, false>;

// Set max dynamic shared memory for the kernel (required for > 48KB)
cudaFuncSetAttribute(kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

cudaLaunchConfig_t config;
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
layout, tensor_map);
}

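Both specializations of invokeFP4Quantization below gate this launcher on the same condition; restated as a hedged helper (the name is ours, not the codebase's):

// The TMA fast path requires NVFP4-style 16-element scale blocks, enough rows
// to justify the producer/consumer pipeline, and n divisible by the per-step
// column footprint: NUM_CONSUMER_WARPS * TMA_COL_TILE = 8 * 64 = 512 elements.
inline bool tmaPathEligible(int m, int n, int sf_vec_size) {
  constexpr int kTmaColChunk = 8 * 64;
  return sf_vec_size == 16 && m >= 1024 && n % kTmaColChunk == 0;
}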
template <typename T, int SF_VEC_SIZE>
void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
QuantizationSFLayout layout, int multiProcessorCount, bool enable_pdl,
cudaStream_t stream) {
#ifdef ENABLE_FP8
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
// Use TMA kernel for large m (high throughput mode)
// TODO: fix the issue when n is not a multiple of NUM_CONSUMER_WARPS * TMA_COL_TILE
constexpr int TMA_COL_CHUNK = 8 * 64; // NUM_CONSUMER_WARPS * TMA_COL_TILE
if constexpr (SF_VEC_SIZE == 16) {
if (SF_VEC_SIZE == 16 && m >= 1024 && n % TMA_COL_CHUNK == 0) {
launchFP4QuantizationTma<BlockScaleQuantizationType::FP8_TO_FP4, T, SF_VEC_SIZE>(
b, m, n, input, SFScale, output, SFOuput, useUE8M0, layout, multiProcessorCount,
enable_pdl, stream);
return;
}
}
// Original non-TMA path for small m or SF_VEC_SIZE != 16
// Grid, Block size.
// Each thread converts 16 values.
dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
@@ -205,10 +347,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, n, input, SFScale,
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);

} else
#endif
{
// Use TMA kernel for large m (high throughput mode)
// TODO: fix the issue when n is not a multiple of NUM_CONSUMER_WARPS * TMA_COL_TILE
constexpr int TMA_COL_CHUNK = 8 * 64; // NUM_CONSUMER_WARPS * TMA_COL_TILE
if constexpr (SF_VEC_SIZE == 16) {
if (SF_VEC_SIZE == 16 && m >= 1024 && n % TMA_COL_CHUNK == 0) {
launchFP4QuantizationTma<BlockScaleQuantizationType::FP16_TO_FP4, T, SF_VEC_SIZE>(
b, m, n, input, SFScale, output, SFOuput, useUE8M0, layout, multiProcessorCount,
enable_pdl, stream);
return;
}
}
// Original non-TMA path for small m or SF_VEC_SIZE != 16
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
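Worked block sizing for the two non-TMA paths, using the per-thread element counts stated in the comments above (n = 4096 is an assumed example):

// FP8 input: 16 elements per thread; half/bf16 input: 8 elements per thread.
int n_ex = 4096;
dim3 block_fp8(std::min(n_ex / 16, 512));  // 256 threads
dim3 block_f16(std::min(n_ex / 8, 512));   // 512 threads (hits the cap)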