diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu
index 458cafd2f6..9021bd0847 100644
--- a/csrc/nv_internal/cpp/kernels/quantization.cu
+++ b/csrc/nv_internal/cpp/kernels/quantization.cu
@@ -70,6 +70,21 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Helper function for grid configuration with swizzled layouts
+
+inline int computeEffectiveRows(int m, QuantizationSFLayout layout) {
+  int effectiveRows = m;
+  bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
+                             layout == QuantizationSFLayout::SWIZZLED_8x4);
+  if (isSfSwizzledLayout) {
+    int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8;
+    int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile;  // Round up to rowTile
+    effectiveRows = numPaddedRows;
+  }
+  return effectiveRows;
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // MXFP8 Quantization
 
@@ -85,7 +100,8 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
   dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+  int effectiveRows = computeEffectiveRows(m, layout);
+  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
   // Launch the cvt kernel.
   cudaLaunchConfig_t config;
@@ -177,7 +193,8 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
     dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512));
     // Get number of blocks per SM (assume we can fully utilize the SM).
     int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-    dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+    int effectiveRows = computeEffectiveRows(m, layout);
+    dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
     // Launch the cvt kernel.
     auto* kernel_instance = useUE8M0
@@ -197,7 +214,8 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
     dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512));
     // Get number of blocks per SM (assume we can fully utilize the SM).
     int const numBlocksPerSM = std::max(1u, 2048u / block.x);
-    dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+    int effectiveRows = computeEffectiveRows(m, layout);
+    dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
 
     // Launch the cvt kernel.
     auto* kernel_instance = useUE8M0
diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
index 237b59eeaf..7abf2eb631 100644
--- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
+++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
@@ -778,56 +778,86 @@ quantize_with_block_size(
   int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;
 
   asm volatile("griddepcontrol.wait;");
 
+  // Input tensor batch/row/col loops.
+  // Optimization: handle rows that contain actual data (hot path) separately from rows that
+  // are pure padding (cold path). This improves performance for small batch sizes with
+  // swizzled layouts.
   for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
-    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
-        std::optional optionalBatchIdx = batchIdx;
-        std::optional optionalNumRows = numRows;
-
-        // The SF output pointer.
-        auto sf_out = cvt_quant_get_sf_out_offset(
-            optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
-            layout);
-
-        // The input tensor offset.
-        int64_t inOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
-        int64_t outOffset =
-            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
-
-        // Set the values to 0 of those are padded columns.
-        if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
-                               quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-          }
-        }
+    // Padding-only rows: if this row lies entirely in the padded region, only its scale
+    // factors need to be zeroed.
+    bool isRowPadding = (rowIdx >= numRows);
+
+    if (isRowPadding) {
+      // Fast path: this row is entirely padding, so only zero out the scale factors.
+      // Note: padding rows do NOT exist in the output tensor (which is sized [numRows, K]);
+      // they only exist in the swizzled scale factor layout. Do NOT write to the output
+      // buffer here.
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional optionalBatchIdx = batchIdx;
+          std::optional optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout,
+              layout);
 
-        // Set the SF padding to 0.
-        if (rowIdx >= numRows || colIdx >= numColThreads) {
           // Set the SF padding to 0.
           if (sf_out != nullptr) {
             sf_out[0] = 0x00;
           }
-        } else {
-          // Load the input vector.
-          PackedVec in_vec = reinterpret_cast(in)[inOffset];
-
-          // Dispatch the quantization kernel.
-          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-            reinterpret_cast<uint32_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
-                                                                          sf_out);
-          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-            reinterpret_cast<uint64_t*>(out)[outOffset] =
-                cvt_warp_fp16_to_mxfp8(in_vec, sf_out);
+        }
+      }
+    } else {
+      // Normal path: this row contains actual data.
+      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+          std::optional optionalBatchIdx = batchIdx;
+          std::optional optionalNumRows = numRows;
+
+          // The SF output pointer.
+          auto sf_out = cvt_quant_get_sf_out_offset(
+              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout,
+              layout);
+
+          // The input tensor offset.
+          int64_t inOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+          int64_t outOffset =
+              static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
+
+          // Set the values to 0 for the padded columns.
+          if (colIdx >= numColThreads && colIdx < numPaddedColThreads) {
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
+                                 quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+            }
+          }
+
+          // Process actual data or column padding.
+          if (colIdx >= numColThreads) {
+            // Column padding: set the SF padding to 0.
+            if (sf_out != nullptr) {
+              sf_out[0] = 0x00;
+            }
+          } else {
+            // Load the input vector.
+            PackedVec in_vec = reinterpret_cast(in)[inOffset];
+
+            // Dispatch the quantization kernel.
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
+                                                                            sf_out);
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] =
+                  cvt_warp_fp16_to_mxfp8(in_vec, sf_out);
+            }
           }
         }
       }
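
Note (not part of the patch): a minimal host-side sketch of the new grid-sizing logic, assuming a stand-in Layout enum in place of QuantizationSFLayout and illustrative values for multiProcessorCount and block.x. It only shows how computeEffectiveRows() rounds the row count up to the 128- or 8-row tile of the swizzled scale-factor layout before the grid is clamped, so that padding rows are spread across blocks instead of being swept by the few blocks a small m would otherwise allow.

#include <algorithm>
#include <cstdio>
#include <initializer_list>

// Stand-in for QuantizationSFLayout; only the swizzled variants matter here.
enum class Layout { SWIZZLED_128x4, SWIZZLED_8x4, LINEAR };

// Mirrors computeEffectiveRows() from the patch: swizzled scale-factor layouts pad the
// row count up to a full row tile (128 or 8), and the grid should cover those padding
// rows too so their scale factors are zeroed in parallel.
int computeEffectiveRows(int m, Layout layout) {
  if (layout != Layout::SWIZZLED_128x4 && layout != Layout::SWIZZLED_8x4) return m;
  int rowTile = (layout == Layout::SWIZZLED_128x4) ? 128 : 8;
  return (m + rowTile - 1) / rowTile * rowTile;  // Round up to the row tile.
}

int main() {
  int multiProcessorCount = 132;  // illustrative SM count
  unsigned blockX = 512;          // illustrative block size
  int numBlocksPerSM = std::max(1u, 2048u / blockX);

  for (int m : {1, 100, 500}) {
    int effectiveRows = computeEffectiveRows(m, Layout::SWIZZLED_128x4);
    int grid = std::min(effectiveRows, multiProcessorCount * numBlocksPerSM);
    // Before the patch the grid was min(m, ...), e.g. a single block for m = 1,
    // which then had to sweep all 128 padded scale-factor rows by itself.
    printf("m=%3d -> effectiveRows=%3d, grid=%3d\n", m, effectiveRows, grid);
  }
  return 0;
}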
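
Similarly, a simplified CPU-side model of the control-flow change in quantize_with_block_size(), using hypothetical numRows and rowTile values; it only illustrates which rows take the padding-only path (scale factors zeroed, output tensor untouched) versus the normal quantization path, not the vectorized conversion itself.

#include <cstdio>

int main() {
  // Hypothetical shape: 3 real rows, SWIZZLED_128x4 pads the scale-factor view
  // up to a 128-row tile, so rows 3..127 exist only as scale-factor padding.
  int numRows = 3;
  int rowTile = 128;
  int numPaddedRowsForSf = (numRows + rowTile - 1) / rowTile * rowTile;

  int hot = 0, cold = 0;
  for (int rowIdx = 0; rowIdx < numPaddedRowsForSf; ++rowIdx) {
    bool isRowPadding = (rowIdx >= numRows);
    if (isRowPadding) {
      // Cold path: the row exists only in the swizzled scale-factor layout, so the
      // kernel writes 0x00 scale factors and never touches the [numRows, K] output.
      ++cold;
    } else {
      // Hot path: quantize the row, writing both output values and scale factors.
      ++hot;
    }
  }
  printf("hot rows: %d, padding-only rows: %d\n", hot, cold);
  return 0;
}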