48 changes: 45 additions & 3 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -85,7 +85,21 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

// For swizzled layouts, we need to consider padded rows to avoid sequential processing.
// This is critical for small batch sizes where m << the padded row count.
int effectiveRows = m;
bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
layout == QuantizationSFLayout::SWIZZLED_8x4);
if (isSfSwizzledLayout) {
int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile
// Use the padded row count for the grid calculation, but cap it at a reasonable limit
// to balance parallelism with occupancy.
effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
}

dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));

// Launch the cvt kernel.
cudaLaunchConfig_t config;
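
A note on the grid-sizing change above, which is repeated in the two invokeFP4Quantization launch paths below: with a swizzled SF layout the kernel must also zero the scale factors of the padded rows, so sizing the grid from m alone leaves a single block walking all of the padding when m is small. The following is a minimal host-side sketch of the new sizing logic; the helper name computeGridRows, the local QuantizationSFLayout copy, and the example SM count are illustrative assumptions rather than code from this PR.

```cpp
#include <algorithm>

// Illustrative stand-in for the real layout enum referenced in the PR.
enum class QuantizationSFLayout { SWIZZLED_128x4, SWIZZLED_8x4, LINEAR };

// Hypothetical helper mirroring the grid.x computation in the hunk above.
int computeGridRows(int m, QuantizationSFLayout layout, int multiProcessorCount,
                    int numBlocksPerSM) {
  int const maxBlocks = multiProcessorCount * numBlocksPerSM;
  int effectiveRows = m;
  if (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
      layout == QuantizationSFLayout::SWIZZLED_8x4) {
    int const rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
    int const numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // round up to the tile
    effectiveRows = std::min(numPaddedRows, maxBlocks);
  }
  return std::min(effectiveRows, maxBlocks);
}

// Example with m = 1, layout = SWIZZLED_128x4, 132 SMs, block.x = 512 (numBlocksPerSM = 4):
//   old sizing: grid.x = min(1, 528)   = 1    -> one block walks all 128 rows serially
//   new sizing: grid.x = min(128, 528) = 128  -> the 127 padding rows are zeroed in parallel
```
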
@@ -177,7 +191,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

// For swizzled layouts, we need to consider padded rows to avoid sequential processing.
// This is critical for small batch sizes where m << the padded row count.
int effectiveRows = m;
bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
layout == QuantizationSFLayout::SWIZZLED_8x4);
if (isSfSwizzledLayout) {
int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile
// Use the padded row count for the grid calculation, but cap it at a reasonable limit
// to balance parallelism with occupancy.
effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
}

dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));

// Launch the cvt kernel.
auto* kernel_instance = useUE8M0
@@ -197,7 +225,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

// For swizzled layouts, we need to consider padded rows to avoid sequential processing.
// This is critical for small batch sizes where m << the padded row count.
int effectiveRows = m;
bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
layout == QuantizationSFLayout::SWIZZLED_8x4);
if (isSfSwizzledLayout) {
int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile
// Use the padded row count for the grid calculation, but cap it at a reasonable limit
// to balance parallelism with occupancy.
effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
}

dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));

// Launch the cvt kernel.
auto* kernel_instance = useUE8M0
114 changes: 71 additions & 43 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
@@ -778,56 +778,84 @@ quantize_with_block_size(
int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;

asm volatile("griddepcontrol.wait;");

// Input tensor batch/row/col loops.
// Optimization: real rows take the full quantization path (hot), while padding rows
// only zero their scale factors (cold). This helps small batch sizes with swizzled layouts.
for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;

// The SF output pointer.
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
layout);

// The input tensor offset.
int64_t inOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
int64_t outOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;

// Set the values to 0 for the padded columns.
if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
}
}
// Fast path for padding rows: a row past numRows produces no output data,
// so we only need to zero out its scale factors.
bool isRowPadding = (rowIdx >= numRows);

if (isRowPadding) {
// Fast path: this row is entirely padding, so only zero out the scale factors.
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;

// The SF output pointer.
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
layout);

// Set the SF padding to 0.
if (rowIdx >= numRows || colIdx >= numColThreads) {
// Set the SF padding to 0.
if (sf_out != nullptr) {
sf_out[0] = 0x00;
}
} else {
// Load the input vector.
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];

// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] =
cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
}
}
} else {
// Normal path: This row contains actual data
for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
std::optional<int> optionalBatchIdx = batchIdx;
std::optional<int> optionalNumRows = numRows;

// The SF output pointer.
auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
layout);

// The input tensor offset.
int64_t inOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
int64_t outOffset =
static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;

// Set the values to 0 for the padded columns.
if (colIdx >= numColThreads && colIdx < numPaddedColThreads) {
// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
}
}

// Process actual data or padding
if (colIdx >= numColThreads) {
// Column padding: Set the SF padding to 0.
if (sf_out != nullptr) {
sf_out[0] = 0x00;
}
} else {
// Load the input vector.
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];

// Dispatch the quantization kernel.
if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
reinterpret_cast<uint32_t*>(out)[outOffset] =
cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
sf_out);
} else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
reinterpret_cast<uint64_t*>(out)[outOffset] =
cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
}
}
}
}
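
To summarise the restructured loop above, the per-element behaviour can be expressed as a small decision function. This is a hypothetical CPU-side sketch, not code from this PR: the Action enum and classify helper are invented names, while the parameters mirror the kernel's numRows, numColThreads, and numPaddedColThreads.

```cpp
// Hypothetical sketch of the per-(row, column-thread) decision made by the
// restructured loop; names other than the kernel's parameters are illustrative.
enum class Action {
  ZeroSfOnly,    // write a zero scale factor, no output data
  ZeroOutAndSf,  // zero the packed output element and its scale factor
  Quantize       // quantize real input and write its scale factor
};

Action classify(int rowIdx, int colIdx, int numRows, int numColThreads,
                int numPaddedColThreads) {
  if (rowIdx >= numRows) {
    // Fast path: the whole row is padding, so only the SF needs zeroing.
    return Action::ZeroSfOnly;
  }
  if (colIdx >= numColThreads) {
    // Real row, padded column: zero the output while still inside the padded
    // width, and zero the scale factor either way.
    return colIdx < numPaddedColThreads ? Action::ZeroOutAndSf : Action::ZeroSfOnly;
  }
  // Real row, real column: load, quantize, and write the scale factor.
  return Action::Quantize;
}
```

Hoisting the row-padding check out of the batch and column loops lets padding rows skip the offset arithmetic and output writes entirely, which pairs with the larger grid computed in the quantization.cu changes above.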