
Commit d771caf

Speed up fp4 quantization for small batch with swizzling

1 parent da01b1b · commit d771caf

File tree

- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

2 files changed: +116 −46

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 45 additions & 3 deletions
```diff
@@ -85,7 +85,21 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
   dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // For swizzled layout, we need to consider padded rows to avoid sequential processing
+  // This is critical for small batch sizes where m << padding
+  int effectiveRows = m;
+  bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
+                             layout == QuantizationSFLayout::SWIZZLED_8x4);
+  if (isSfSwizzledLayout) {
+    int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
+    int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // Round up to rowTile
+    // Use padded rows for grid calculation, but cap at reasonable limit
+    // to balance parallelism with occupancy
+    effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
+  }
+
+  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
   // Launch the cvt kernel.
   cudaLaunchConfig_t config;
```
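To make the effect concrete, here is a minimal standalone sketch of the grid-size arithmetic above. The values (m = 4 rows, 132 SMs, block.x = 512 so numBlocksPerSM = 4, SWIZZLED_128x4 layout) are illustrative assumptions, not taken from the commit:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Assumed example values (not from the commit).
  int m = 4;                      // small batch: only 4 real rows
  int multiProcessorCount = 132;  // number of SMs
  int numBlocksPerSM = 4;         // max(1, 2048 / 512)
  int rowTile = 128;              // SWIZZLED_128x4 pads rows to a multiple of 128

  // Old behavior: the grid is capped by the number of real rows.
  int oldGrid = std::min(m, multiProcessorCount * numBlocksPerSM);  // 4 blocks

  // New behavior: the grid also covers the padded rows of the swizzled SF layout.
  int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // 128
  int effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
  int newGrid = std::min(effectiveRows, multiProcessorCount * numBlocksPerSM);  // 128 blocks

  std::printf("old grid = %d, new grid = %d\n", oldGrid, newGrid);
  return 0;
}
```

In this sketch the old launch would use 4 blocks, so each block's row-stride loop walks roughly 32 of the 128 padded rows sequentially; the new launch uses 128 blocks and the padding rows are handled in parallel.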
```diff
@@ -177,7 +191,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
   dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // For swizzled layout, we need to consider padded rows to avoid sequential processing
+  // This is critical for small batch sizes where m << padding
+  int effectiveRows = m;
+  bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
+                             layout == QuantizationSFLayout::SWIZZLED_8x4);
+  if (isSfSwizzledLayout) {
+    int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
+    int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // Round up to rowTile
+    // Use padded rows for grid calculation, but cap at reasonable limit
+    // to balance parallelism with occupancy
+    effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
+  }
+
+  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
   // Launch the cvt kernel.
   auto* kernel_instance = useUE8M0
@@ -197,7 +225,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
   dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // For swizzled layout, we need to consider padded rows to avoid sequential processing
+  // This is critical for small batch sizes where m << padding
+  int effectiveRows = m;
+  bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
+                             layout == QuantizationSFLayout::SWIZZLED_8x4);
+  if (isSfSwizzledLayout) {
+    int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
+    int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // Round up to rowTile
+    // Use padded rows for grid calculation, but cap at reasonable limit
+    // to balance parallelism with occupancy
+    effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM);
+  }
+
+  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
   // Launch the cvt kernel.
   auto* kernel_instance = useUE8M0
```
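All three launch sites in this file (MxFP8 and the two FP4 paths) apply the same change; the shared piece is rounding m up to the row tile of the swizzled scale-factor layout. A small sketch of just that arithmetic, using a hypothetical helper name (not part of the commit) and example values checked at compile time:

```cpp
// Hypothetical helper (illustration only): round the row count up to the
// scale-factor layout's row tile, as each launch site above does inline.
constexpr int roundUpToRowTile(int m, int rowTile) {
  return (m + rowTile - 1) / rowTile * rowTile;
}

// SWIZZLED_128x4 tiles rows by 128, SWIZZLED_8x4 by 8.
static_assert(roundUpToRowTile(1, 128) == 128, "");
static_assert(roundUpToRowTile(300, 128) == 384, "");
static_assert(roundUpToRowTile(1, 8) == 8, "");
static_assert(roundUpToRowTile(300, 8) == 304, "");
```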

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 71 additions & 43 deletions
```diff
@@ -778,56 +778,84 @@ quantize_with_block_size(
   int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;
 
   asm volatile("griddepcontrol.wait;");
+
   // Input tensor batch/row/col loops.
+  // Optimization: Iterate over actual rows first (hot path), then padding rows (cold path)
+  // This improves performance for small batch sizes with swizzled layout
   for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
-    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
-        std::optional<int> optionalBatchIdx = batchIdx;
-        std::optional<int> optionalNumRows = numRows;
-
-        // The SF output pointer.
-        auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-            optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
-            layout);
-
-        // The input tensor offset.
-        int64_t inOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
-        int64_t outOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
-
-        // Set the values to 0 of those are padded columns.
-        if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
-                               quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-          }
-        }
+    // Early exit for padding-only blocks: if this block only processes padding rows,
+    // we can skip the batch loop and just zero out the scale factors
+    bool isRowPadding = (rowIdx >= numRows);
+
+    if (isRowPadding) {
+      // Fast path: This row is entirely padding, only zero out scale factors
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional<int> optionalBatchIdx = batchIdx;
+          std::optional<int> optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
+              layout);
 
-        // Set the SF padding to 0.
-        if (rowIdx >= numRows || colIdx >= numColThreads) {
           // Set the SF padding to 0.
           if (sf_out != nullptr) {
             sf_out[0] = 0x00;
           }
-        } else {
-          // Load the input vector.
-          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
-                                                                          sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+        }
+      }
+    } else {
+      // Normal path: This row contains actual data
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional<int> optionalBatchIdx = batchIdx;
+          std::optional<int> optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
+              layout);
+
+          // The input tensor offset.
+          int64_t inOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          int64_t outOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
+
+          // Set the values to 0 of those are padded columns.
+          if (colIdx >= numColThreads && colIdx < numPaddedColThreads) {
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
+                                 quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+            }
+          }
+
+          // Process actual data or padding
+          if (colIdx >= numColThreads) {
+            // Column padding: Set the SF padding to 0.
+            if (sf_out != nullptr) {
+              sf_out[0] = 0x00;
+            }
+          } else {
+            // Load the input vector.
+            PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
+                                                                            sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+            }
           }
         }
       }
```
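The kernel-side change restructures the row loop of quantize_with_block_size into a padding-only fast path and a data path, so the row-padding check no longer runs per column in the innermost loop. A simplified control-flow sketch of the new structure (offsets and quantization calls elided; an illustration under those assumptions, not the actual kernel):

```cpp
// Simplified sketch of the restructured row loop (illustrative only).
__global__ void sketchRowLoop(int numRows, int numPaddedRowsForSf, int numbatches,
                              int numColThreadsForSf) {
  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
    if (rowIdx >= numRows) {
      // Fast path: the row is pure padding. Only the scale factors need zeroing;
      // there are no input loads, no output-tensor writes, and no quantization math.
      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
          // zero the SF entry for (batchIdx, rowIdx, colIdx), if it exists
        }
      }
    } else {
      // Data path: the same per-column work as before, minus the per-column
      // rowIdx >= numRows check.
      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
          // zero padded output columns, quantize real columns, write SFs
        }
      }
    }
  }
}
```

Together with the larger grid from the host-side change, the rows that exist only to satisfy the swizzled SF layout land on their own blocks and take the cheap path, which is the small-batch speedup the commit's comments describe.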
