@@ -778,56 +778,84 @@ quantize_with_block_size(
   int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;
 
   asm volatile("griddepcontrol.wait;");
+
   // Input tensor batch/row/col loops.
+  // Optimization: handle actual rows (hot path) and padding rows (cold path) separately.
+  // This improves performance for small batch sizes with the swizzled layout.
   for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
-    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
-        std::optional<int> optionalBatchIdx = batchIdx;
-        std::optional<int> optionalNumRows = numRows;
-
-        // The SF output pointer.
-        auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-            optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
-            layout);
-
-        // The input tensor offset.
-        int64_t inOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
-        int64_t outOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
-
-        // Set the values to 0 if those are padded columns.
-        if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
-                               quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-          }
-        }
+    // Rows at or beyond numRows are pure padding: the quantization work can be
+    // skipped and only the scale factors need to be zeroed out.
+    bool isRowPadding = (rowIdx >= numRows);
+
+    if (isRowPadding) {
+      // Fast path: this row is entirely padding, so only zero out the scale factors.
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional<int> optionalBatchIdx = batchIdx;
+          std::optional<int> optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
+              layout);
 
-        // Set the SF padding to 0.
-        if (rowIdx >= numRows || colIdx >= numColThreads) {
           // Set the SF padding to 0.
           if (sf_out != nullptr) {
             sf_out[0] = 0x00;
           }
-        } else {
-          // Load the input vector.
-          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
-                                                                          sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+        }
+      }
+    } else {
+      // Normal path: this row contains actual data.
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional<int> optionalBatchIdx = batchIdx;
+          std::optional<int> optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
+              layout);
+
+          // The input tensor offset.
+          int64_t inOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          int64_t outOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
+
+          // Set the values to 0 if those are padded columns.
+          if (colIdx >= numColThreads && colIdx < numPaddedColThreads) {
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
+                                 quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+            }
+          }
+
+          // Process actual data or padding.
+          if (colIdx >= numColThreads) {
+            // Column padding: Set the SF padding to 0.
+            if (sf_out != nullptr) {
+              sf_out[0] = 0x00;
+            }
+          } else {
+            // Load the input vector.
+            PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
+                                                                            sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+            }
           }
         }
       }
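
The structural change above is easier to see in isolation. Below is a minimal, hypothetical CUDA sketch of the same row-split pattern, not the actual quantize_with_block_size kernel: padding rows (row >= numRows) take a cheap path that only zeroes scale factors, while real rows run the full load/quantize/store path. The kernel name, the int8_t stand-in for the FP4/MXFP8 conversion, and the per-element scale factor are illustrative placeholders.

// Illustrative sketch only; simplified shapes, types, and quantization.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void rowSplitQuantSketch(float const* in, int8_t* out, float* sfOut,
                                    int numRows, int numPaddedRows, int numCols) {
  // Grid-stride loop over padded rows, mirroring the kernel's outer loop.
  for (int row = blockIdx.x; row < numPaddedRows; row += gridDim.x) {
    if (row >= numRows) {
      // Cold path: the row is pure padding, so only the scale factors are zeroed.
      for (int col = threadIdx.x; col < numCols; col += blockDim.x) {
        sfOut[row * numCols + col] = 0.0f;
      }
    } else {
      // Hot path: load the input, convert it (placeholder for the FP4/MXFP8
      // conversion), and write the quantized value plus its scale factor.
      for (int col = threadIdx.x; col < numCols; col += blockDim.x) {
        float v = in[row * numCols + col];
        out[row * numCols + col] = static_cast<int8_t>(v);
        sfOut[row * numCols + col] = 1.0f;
      }
    }
  }
}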