From d771caf88f66c4b2a8e88f41b63a175e6b2f751e Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 3 Nov 2025 17:44:37 +0000 Subject: [PATCH 1/5] Speed up fp4 quantization for small batch with swizzling --- csrc/nv_internal/cpp/kernels/quantization.cu | 48 +++++++- .../tensorrt_llm/kernels/quantization.cuh | 114 +++++++++++------- 2 files changed, 116 insertions(+), 46 deletions(-) diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index 458cafd2f6..7f0f7cf355 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -85,7 +85,21 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // For swizzled layout, we need to consider padded rows to avoid sequential processing + // This is critical for small batch sizes where m << padding + int effectiveRows = m; + bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4); + if (isSfSwizzledLayout) { + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile + // Use padded rows for grid calculation, but cap at reasonable limit + // to balance parallelism with occupancy + effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); + } + + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. cudaLaunchConfig_t config; @@ -177,7 +191,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // For swizzled layout, we need to consider padded rows to avoid sequential processing + // This is critical for small batch sizes where m << padding + int effectiveRows = m; + bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4); + if (isSfSwizzledLayout) { + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile + // Use padded rows for grid calculation, but cap at reasonable limit + // to balance parallelism with occupancy + effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); + } + + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. auto* kernel_instance = useUE8M0 @@ -197,7 +225,21 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // For swizzled layout, we need to consider padded rows to avoid sequential processing + // This is critical for small batch sizes where m << padding + int effectiveRows = m; + bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4); + if (isSfSwizzledLayout) { + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile + // Use padded rows for grid calculation, but cap at reasonable limit + // to balance parallelism with occupancy + effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); + } + + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. auto* kernel_instance = useUE8M0 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index 237b59eeaf..a6cac99dca 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -778,56 +778,84 @@ quantize_with_block_size( int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD; asm volatile("griddepcontrol.wait;"); + // Input tensor batch/row/col loops. + // Optimization: Iterate over actual rows first (hot path), then padding rows (cold path) + // This improves performance for small batch sizes with swizzled layout for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) { - for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { - for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { - std::optional optionalBatchIdx = batchIdx; - std::optional optionalNumRows = numRows; - - // The SF output pointer. - auto sf_out = cvt_quant_get_sf_out_offset( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, - layout); - - // The input tensor offset. - int64_t inOffset = - static_cast(batchIdx * numRows + rowIdx) * numColThreads + colIdx; - int64_t outOffset = - static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; - - // Set the values to 0 of those are padded columns. - if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) { - // Dispatch the quantization kernel. - if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { - reinterpret_cast(out)[outOffset] = 0u; - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || - quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { - reinterpret_cast(out)[outOffset] = 0ull; - } - } + // Early exit for padding-only blocks: if this block only processes padding rows, + // we can skip the batch loop and just zero out the scale factors + bool isRowPadding = (rowIdx >= numRows); + + if (isRowPadding) { + // Fast path: This row is entirely padding, only zero out scale factors + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { + std::optional optionalBatchIdx = batchIdx; + std::optional optionalNumRows = numRows; + + // The SF output pointer. + auto sf_out = cvt_quant_get_sf_out_offset( + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, + layout); - // Set the SF padding to 0. - if (rowIdx >= numRows || colIdx >= numColThreads) { // Set the SF padding to 0. if (sf_out != nullptr) { sf_out[0] = 0x00; } - } else { - // Load the input vector. - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - - // Dispatch the quantization kernel. - if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, - sf_out); - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + } + } + } else { + // Normal path: This row contains actual data + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { + std::optional optionalBatchIdx = batchIdx; + std::optional optionalNumRows = numRows; + + // The SF output pointer. + auto sf_out = cvt_quant_get_sf_out_offset( + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, + layout); + + // The input tensor offset. + int64_t inOffset = + static_cast(batchIdx * numRows + rowIdx) * numColThreads + colIdx; + int64_t outOffset = + static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; + + // Set the values to 0 of those are padded columns. + if (colIdx >= numColThreads && colIdx < numPaddedColThreads) { + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[outOffset] = 0u; + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || + quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[outOffset] = 0ull; + } + } + + // Process actual data or padding + if (colIdx >= numColThreads) { + // Column padding: Set the SF padding to 0. + if (sf_out != nullptr) { + sf_out[0] = 0x00; + } + } else { + // Load the input vector. + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, + sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + } } } } From 43d7e52cb1de130bafe24ef8b4819b0c557b972d Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 3 Nov 2025 18:10:18 +0000 Subject: [PATCH 2/5] Address review comments --- csrc/nv_internal/cpp/kernels/quantization.cu | 63 +++++++------------ .../tensorrt_llm/kernels/quantization.cuh | 14 ++++- 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index 7f0f7cf355..acb9acfd14 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -70,6 +70,24 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper function for grid configuration with swizzled layouts + +inline int computeEffectiveRows(int m, QuantizationSFLayout layout, int multiProcessorCount, + int numBlocksPerSM) { + int effectiveRows = m; + bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4); + if (isSfSwizzledLayout) { + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile + // Use padded rows for grid calculation, but cap at reasonable limit + // to balance parallelism with occupancy + effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); + } + return effectiveRows; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // MXFP8 Quantization @@ -85,20 +103,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - - // For swizzled layout, we need to consider padded rows to avoid sequential processing - // This is critical for small batch sizes where m << padding - int effectiveRows = m; - bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || - layout == QuantizationSFLayout::SWIZZLED_8x4); - if (isSfSwizzledLayout) { - int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; - int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile - // Use padded rows for grid calculation, but cap at reasonable limit - // to balance parallelism with occupancy - effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); - } - + int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. @@ -191,20 +196,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - - // For swizzled layout, we need to consider padded rows to avoid sequential processing - // This is critical for small batch sizes where m << padding - int effectiveRows = m; - bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || - layout == QuantizationSFLayout::SWIZZLED_8x4); - if (isSfSwizzledLayout) { - int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; - int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile - // Use padded rows for grid calculation, but cap at reasonable limit - // to balance parallelism with occupancy - effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); - } - + int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. @@ -225,20 +217,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - - // For swizzled layout, we need to consider padded rows to avoid sequential processing - // This is critical for small batch sizes where m << padding - int effectiveRows = m; - bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || - layout == QuantizationSFLayout::SWIZZLED_8x4); - if (isSfSwizzledLayout) { - int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; - int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile - // Use padded rows for grid calculation, but cap at reasonable limit - // to balance parallelism with occupancy - effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); - } - + int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index a6cac99dca..c03ee1b32d 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -788,7 +788,7 @@ quantize_with_block_size( bool isRowPadding = (rowIdx >= numRows); if (isRowPadding) { - // Fast path: This row is entirely padding, only zero out scale factors + // Fast path: This row is entirely padding, zero out both quantized output and scale factors for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { std::optional optionalBatchIdx = batchIdx; @@ -799,6 +799,18 @@ quantize_with_block_size( optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout); + // Zero the quantized output for all columns (both actual and padded column range) + if (colIdx < numPaddedColThreads) { + int64_t outOffset = + static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[outOffset] = 0u; + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || + quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[outOffset] = 0ull; + } + } + // Set the SF padding to 0. if (sf_out != nullptr) { sf_out[0] = 0x00; From d0602640feec0468cb84722d85daa3b79fe80c29 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 3 Nov 2025 18:17:33 +0000 Subject: [PATCH 3/5] Fix unit test failure --- .../tensorrt_llm/kernels/quantization.cuh | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index c03ee1b32d..4a2b15f8d0 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -788,7 +788,9 @@ quantize_with_block_size( bool isRowPadding = (rowIdx >= numRows); if (isRowPadding) { - // Fast path: This row is entirely padding, zero out both quantized output and scale factors + // Fast path: This row is entirely padding, only zero out scale factors. + // Note: Padding rows do NOT exist in the output tensor (which is sized [numRows, K]), + // they only exist in the swizzled scale factor layout. Do NOT write to output buffer here. for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { std::optional optionalBatchIdx = batchIdx; @@ -799,18 +801,6 @@ quantize_with_block_size( optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout); - // Zero the quantized output for all columns (both actual and padded column range) - if (colIdx < numPaddedColThreads) { - int64_t outOffset = - static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; - if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { - reinterpret_cast(out)[outOffset] = 0u; - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || - quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { - reinterpret_cast(out)[outOffset] = 0ull; - } - } - // Set the SF padding to 0. if (sf_out != nullptr) { sf_out[0] = 0x00; From 26355ca43feba8c9c52d329c5c22e3b7c5c08c0c Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Mon, 3 Nov 2025 18:31:56 +0000 Subject: [PATCH 4/5] Address reviewed comments --- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index 4a2b15f8d0..7abf2eb631 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -798,7 +798,7 @@ quantize_with_block_size( // The SF output pointer. auto sf_out = cvt_quant_get_sf_out_offset( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout, layout); // Set the SF padding to 0. @@ -816,7 +816,7 @@ quantize_with_block_size( // The SF output pointer. auto sf_out = cvt_quant_get_sf_out_offset( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout, layout); // The input tensor offset. From 95da76cd7a3b9ea8f83f6d6c186934a84792ff1a Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Tue, 4 Nov 2025 22:46:11 +0000 Subject: [PATCH 5/5] Address comments --- csrc/nv_internal/cpp/kernels/quantization.cu | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index acb9acfd14..9021bd0847 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -73,17 +73,14 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const //////////////////////////////////////////////////////////////////////////////////////////////////// // Helper function for grid configuration with swizzled layouts -inline int computeEffectiveRows(int m, QuantizationSFLayout layout, int multiProcessorCount, - int numBlocksPerSM) { +inline int computeEffectiveRows(int m, QuantizationSFLayout layout) { int effectiveRows = m; bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || layout == QuantizationSFLayout::SWIZZLED_8x4); if (isSfSwizzledLayout) { int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile - // Use padded rows for grid calculation, but cap at reasonable limit - // to balance parallelism with occupancy - effectiveRows = std::min(numPaddedRows, multiProcessorCount * numBlocksPerSM); + effectiveRows = numPaddedRows; } return effectiveRows; } @@ -103,7 +100,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); + int effectiveRows = computeEffectiveRows(m, layout); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. @@ -196,7 +193,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); + int effectiveRows = computeEffectiveRows(m, layout); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. @@ -217,7 +214,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - int effectiveRows = computeEffectiveRows(m, layout, multiProcessorCount, numBlocksPerSM); + int effectiveRows = computeEffectiveRows(m, layout); dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel.