
Conversation

@bkryu
Collaborator

@bkryu bkryu commented Nov 3, 2025

📌 Description

Performance optimization for the fp4_quantize() function. The performance issue was raised in issues #1734 and #2021.

The observed behavior was slow performance when is_sf_swizzled_layout=True (as opposed to False). The root causes were:

  • Excessive Padding Overhead: Swizzled layouts require row padding to tile boundaries; SWIZZLED_128x4 pads to multiples of 128 rows and SWIZZLED_8x4 pads to multiples of 8 rows.
    • For batch_size=1 with SWIZZLED_128x4, 127 out of 128 rows are padding (99.2% wasted work).
  • Sequential Processing: The original grid launch used grid.x = min(m, multiProcessorCount * numBlocksPerSM), so for batch_size=1 only 1 block was launched (see the sketch below).
    • That single block iterated sequentially over all 128 padded rows.
    • Each padding row still computed scale factors, checked bounds, and performed conditional logic.
  • No Fast Path: Every row (real or padding) went through the same expensive code path with multiple conditional branches.
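
To make the numbers concrete, here is a rough, self-contained sketch of the pre-fix launch math (the SM count and blocks-per-SM below are illustrative values, not the ones the library actually queries from the device):

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative values only; the real launch code queries these from the device.
  int m = 1;                       // batch size (rows of real data)
  int multiProcessorCount = 132;   // example SM count
  int numBlocksPerSM = 4;          // example occupancy estimate
  int gridX = std::min(m, multiProcessorCount * numBlocksPerSM);  // old scheme -> 1 block
  // With SWIZZLED_128x4, scale-factor rows are padded up to a multiple of 128, so this
  // single block walks all 128 padded rows sequentially; 127 of them are pure padding work.
  printf("old grid.x = %d, padded rows walked by that block = %d\n", gridX, 128);
  return 0;
}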

The fix:

  1. Kernel-Level Early Exit Fast Path (quantization.cuh): Added a branch-divergence optimization with separate handling for padding vs. data rows.
    • Padding rows now execute ~10× fewer instructions, skip memory loads/stores for input/output data, and incur less register pressure and divergence overhead.
  2. Host-Level Parallel Grid Launch (quantization.cu): Modified the grid calculation to launch blocks proportional to padded rows instead of actual rows (see the sketch after this list).
    • For batch_size=1 with SWIZZLED_128x4, this launches up to 128 blocks instead of 1; each block processes 1 row in parallel instead of sequentially, helping to achieve full GPU occupancy even at small batch sizes.
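
Below is a minimal sketch of the new grid-sizing idea, assuming a hypothetical SfLayout enum and illustrative SM/occupancy numbers; the actual helper (computeEffectiveRows) lives in csrc/nv_internal/cpp/kernels/quantization.cu and uses the library's own layout type and queried device properties:

#include <algorithm>
#include <cstdio>

// Hypothetical layout tag for this sketch; the real code uses its own layout enum.
enum class SfLayout { Swizzled128x4, Swizzled8x4, Linear };

// Round the row count up to the layout's tile height, then clamp to the number of
// blocks the GPU can keep resident (multiProcessorCount * numBlocksPerSM).
int computeEffectiveRows(int m, SfLayout layout, int multiProcessorCount, int numBlocksPerSM) {
  int rowTile = (layout == SfLayout::Swizzled128x4) ? 128
                : (layout == SfLayout::Swizzled8x4) ? 8
                                                    : 1;
  int paddedRows = (m + rowTile - 1) / rowTile * rowTile;  // pad m up to the tile boundary
  return std::min(paddedRows, multiProcessorCount * numBlocksPerSM);
}

int main() {
  // batch_size=1 with SWIZZLED_128x4: 128 blocks instead of min(1, 132*4) = 1 under the old rule.
  printf("grid.x = %d\n", computeEffectiveRows(1, SfLayout::Swizzled128x4, 132, 4));   // 128
  // batch_size=300 with SWIZZLED_8x4: padded to 304 rows, still below the SM clamp of 528.
  printf("grid.x = %d\n", computeEffectiveRows(300, SfLayout::Swizzled8x4, 132, 4));   // 304
  return 0;
}

On the device side, each block then checks whether its row index lies beyond the real data (rowIdx >= numRows); if it does, the block takes the early-exit path that only zeroes the corresponding scale-factor outputs and skips input loads and quantization entirely.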

fp4_quantize() performance before the fix:

$ python3 bench_fp4_quantize.py 
+------------+---------------------+-------------------------+
| batch size | swizzled_times (us) | non_swizzled_times (us) |
+------------+---------------------+-------------------------+
|    1.0     |        71.52        |          3.136          |
|    2.0     |       37.152        |          3.168          |
|    4.0     |       19.904        |          3.168          |
|    8.0     |       11.296        |           3.2           |
|    16.0    |        7.103        |          3.296          |
|    32.0    |        4.96         |          3.376          |
|    64.0    |        4.128        |          3.487          |
|   128.0    |        3.808        |          3.648          |
|   256.0    |        4.32         |          4.161          |
|   512.0    |        5.472        |          5.184          |
+------------+---------------------+-------------------------+

After the fix in this PR:

$ python3 bench_fp4_quantize.py 
+------------+---------------------+-------------------------+
| batch size | swizzled_times (us) | non_swizzled_times (us) |
+------------+---------------------+-------------------------+
|    1.0     |        3.456        |          3.264          |
|    2.0     |        3.488        |          3.296          |
|    4.0     |        3.536        |          3.296          |
|    8.0     |        3.52         |          3.296          |
|    16.0    |        3.52         |          3.456          |
|    32.0    |        3.696        |          3.488          |
|    64.0    |        3.744        |          3.584          |
|   128.0    |        3.936        |          3.776          |
|   256.0    |        4.384        |          4.288          |
|   512.0    |        5.568        |          5.248          |
+------------+---------------------+-------------------------+

The bench_fp4_quantize.py script used for benchmarking (adapted from #1734):

from flashinfer.testing.utils import bench_gpu_time_with_cupti
from flashinfer import fp4_quantize
import torch
import numpy as np
import pandas as pd
from tabulate import tabulate

A_scale = torch.randn(16).cuda().float()
bsz = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
swizzled_times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = np.median(bench_gpu_time_with_cupti(
            lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True),
            dry_run_iters = 10, 
            repeat_iters = 100,
            )
        ) * 1000
    swizzled_times.append(t)

non_swizzled_times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = np.median(bench_gpu_time_with_cupti(
            lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=False),
            dry_run_iters = 10,
            repeat_iters = 100,
            )
        ) * 1000
    non_swizzled_times.append(t)


summary_df = pd.DataFrame({
    "batch size": bsz,
    "swizzled_times (us)": swizzled_times,
    "non_swizzled_times (us)": non_swizzled_times,
})

# Round numeric columns to three decimals before printing
summary_df_rounded = summary_df.copy()
summary_df_rounded["batch size"] = summary_df_rounded["batch size"].astype(int)
summary_df_rounded["swizzled_times (us)"] = summary_df_rounded["swizzled_times (us)"].round(3)
summary_df_rounded["non_swizzled_times (us)"] = summary_df_rounded["non_swizzled_times (us)"].round(3)
print(tabulate(summary_df_rounded, headers='keys', tablefmt='pretty', showindex=False))

🔍 Related Issues

#1734
#2021

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Improved quantization for swizzled memory layouts by adjusting how effective processing rows are computed to better utilize GPU resources.
    • Added early-exit handling for padding-only rows so padding outputs are zeroed without processing data.
    • Ensured consistent zeroing of scale/format outputs for padded columns across all quantization paths.

@coderabbitai
Contributor

coderabbitai bot commented Nov 3, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Grid sizing for several quantization kernels now computes effectiveRows for swizzled SF layouts and clamps to SM capacity; the per-row quantization loop was refactored to early-exit on padding-only rows (skipping data work) while zeroing SF outputs in both padding and data paths.

Changes

Cohort / File(s) Summary
Grid configuration updates
csrc/nv_internal/cpp/kernels/quantization.cu
Added computeEffectiveRows(...) and replaced prior m-based grid.x calculations in MXFP8 and FP4 quantization launch paths; rounds m up to layout tile sizes (128 or 8) and clamps to multiProcessorCount * numBlocksPerSM.
Quantization loop branching
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
Refactored quantize_with_block_size to detect padding-only rows and take an early-exit path that zeros SF outputs; retained existing data-path quantization while ensuring SF padding zeroing runs in both branches.

Sequence Diagram(s)

sequenceDiagram
    participant Host
    participant KernelLauncher
    participant GPUKernel
    note right of KernelLauncher: computeEffectiveRows(m, layout, SMs, blocksPerSM)
    Host->>KernelLauncher: request quantization (m, layout, ...)
    KernelLauncher->>GPUKernel: launch kernel with grid.x = effectiveRows
    GPUKernel->>GPUKernel: compute rowIdx
    alt rowIdx is padding
        GPUKernel->>GPUKernel: zero SF outputs for padding columns
        GPUKernel-->>GPUKernel: skip input load & quantize
    else data row
        GPUKernel->>GPUKernel: load input vector
        GPUKernel->>GPUKernel: perform quantization & write outputs
        GPUKernel->>GPUKernel: zero SF outputs for padding columns (if any)
    end
    GPUKernel->>Host: kernel completes

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Review focus:
    • computeEffectiveRows rounding/clamping for SWIZZLED_128x4 and SWIZZLED_8x4.
    • Kernel launch grid.x replacements in MXFP8 and FP4 code paths (off-by-one/rounding).
    • Padding branch correctness: row/column offsets, SF zeroing placement, and memory-write safety.
    • Divergence/occupancy impacts from added branching.

Suggested reviewers

  • djmmoss
  • yongwww
  • wenscarl
  • cyx-6

Poem

🐰
I hop through swizzled rows and beam,
Rounding tiles into a steady stream,
Padding waits — I gently clear,
Zeroed fields keep pathways near,
Kernels hum and dreams compute.

Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title clearly summarizes the main performance optimization: speeding up fp4 quantization for small batch sizes using swizzling for cutlass MoE, which directly aligns with the changes made.
  • Description check: ✅ Passed. The description comprehensively covers all required template sections: detailed problem explanation with performance data, root causes, implemented fixes, benchmark results, and pre-commit/test checklist completion.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26355ca and 95da76c.

📒 Files selected for processing (1)
  • csrc/nv_internal/cpp/kernels/quantization.cu (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/nv_internal/cpp/kernels/quantization.cu
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance of FP4 quantization, particularly for scenarios involving small batch sizes within Cutlass Mixture of Experts (MoE) applications. The improvements are achieved through strategic adjustments to CUDA kernel launch parameters, ensuring optimal hardware utilization for swizzled data layouts, and by refining the quantization kernel's logic to efficiently process padded data, thereby minimizing redundant computations.

Highlights

  • CUDA Grid Launch Optimization: Modified the CUDA kernel launch configuration for FP4 quantization kernels (invokeMxFP8Quantization and invokeFP4Quantization) to dynamically adjust the grid size based on 'effective rows' for swizzled layouts. This ensures better parallelism and occupancy, especially for small batch sizes where padding is significant.
  • Swizzled Layout Handling: Introduced logic to specifically handle SWIZZLED_128x4 and SWIZZLED_8x4 quantization layouts by calculating numPaddedRows to inform the grid dimension, preventing sequential processing and improving performance for small m values.
  • Optimized Padding Row Processing: Refactored the quantize_with_block_size kernel to differentiate between actual data rows and padding rows. A 'fast path' was implemented for padding-only rows, where only scale factors are zeroed out, avoiding unnecessary input loading and full quantization computations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@bkryu
Collaborator Author

bkryu commented Nov 3, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !106 has been created, and the CI pipeline #37820809 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces performance optimizations for FP4 quantization, particularly for small batch sizes with swizzled layouts. The changes involve adjusting the CUDA grid dimensions to account for padded rows and refactoring the quantization kernel to handle padding rows more efficiently.

My review focuses on improving code maintainability by addressing code duplication. I've identified two areas where logic is repeated and have suggested creating helper functions or restructuring the code to eliminate this duplication. These changes should make the code cleaner and easier to maintain without affecting the performance improvements.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

808-862: Refactor duplicated SF output pointer retrieval.

The SF output pointer retrieval code (lines 816-818) is duplicated from the padding path (lines 798-800). This duplication increases maintenance burden.

Consider hoisting the SF pointer retrieval outside the if (isRowPadding) branch to eliminate duplication:

+    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
+      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
+        std::optional<int> optionalBatchIdx = batchIdx;
+        std::optional<int> optionalNumRows = numRows;
+
+        // The SF output pointer (retrieved once for both paths).
+        auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+            optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
+            layout);
+
     if (isRowPadding) {
-      // Fast path: This row is entirely padding, only zero out scale factors
-      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
-          std::optional<int> optionalBatchIdx = batchIdx;
-          std::optional<int> optionalNumRows = numRows;
-
-          // The SF output pointer.
-          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
-              layout);
-
-          // Set the SF padding to 0.
-          if (sf_out != nullptr) {
-            sf_out[0] = 0x00;
-          }
-        }
-      }
+      // Fast path: zero SF only
+      if (sf_out != nullptr) {
+        sf_out[0] = 0x00;
+      }
     } else {
-      // Normal path: This row contains actual data
-      for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
-        for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
-          std::optional<int> optionalBatchIdx = batchIdx;
-          std::optional<int> optionalNumRows = numRows;
-
-          // The SF output pointer.
-          auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-              optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
-              layout);
-
+      // Normal path: process data
           // ... rest of data processing ...
+      }
+    }
+  }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between da01b1b and d771caf.

📒 Files selected for processing (2)
  • csrc/nv_internal/cpp/kernels/quantization.cu (3 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
  • layout (29-47)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/nv_internal/cpp/kernels/quantization.cu (1)

88-102: Verify grid dimension calculation is correct.

The grid dimension calculation appears to apply std::min twice with the same upper bound multiProcessorCount * numBlocksPerSM, which is redundant.

Line 102 applies std::min(effectiveRows, multiProcessorCount * numBlocksPerSM), but effectiveRows is already capped at multiProcessorCount * numBlocksPerSM on line 99. The second std::min is redundant.

Apply this diff to simplify:

-  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
+  dim3 grid(effectiveRows);

This same issue exists in lines 208 and 242 for the other two functions.

Likely an incorrect or invalid review comment.

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

781-788: Grid configuration correctly ensures padding rows are assigned to blocks.

The optimization has been properly implemented. For swizzled layouts, effectiveRows is set to std::min(PadUpFn(numRows, rowTile), SM_limit) in the host code, which expands the grid to include padded rows. Since both the host and kernel use the same PadUpFn macro (defined as ((X + Y - 1) / (Y) * (Y))), the kernel's numPaddedRowsForSf computation matches the grid sizing. This ensures gridDim.x >= numRows when m is not divisible by rowTile, allowing blocks to reach indices where rowIdx >= numRows and trigger the padding-only optimization path. The original concern about blocks not reaching padding rows has been addressed.

@bkryu
Collaborator Author

bkryu commented Nov 3, 2025

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #37820809 has been cancelled.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d771caf and 43d7e52.

📒 Files selected for processing (2)
  • csrc/nv_internal/cpp/kernels/quantization.cu (4 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
csrc/trtllm_gemm_runner.cu (8)
  • m (111-126)
  • m (111-111)
  • m (128-179)
  • m (128-130)
  • m (181-236)
  • m (181-181)
  • m (238-250)
  • m (238-238)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

783-784: Clarify the optimization description.

The comment states "Iterate over actual rows first (hot path), then padding rows (cold path)", but the loop at line 785 iterates sequentially from 0 to numPaddedRowsForSf. The optimization is actually an early-exit fast path for padding rows (via isRowPadding check), not a reordering of iteration.

Consider revising the comment to:

-  // Optimization: Iterate over actual rows first (hot path), then padding rows (cold path)
-  // This improves performance for small batch sizes with swizzled layout
+  // Optimization: Fast-path early exit for padding rows to skip input loading and quantization
+  // This improves performance for small batches with swizzled layout
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 43d7e52 and d060264.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

829-861: Column padding logic is correct.

The nested conditionals correctly handle three ranges:

  1. colIdx < numColThreads: Process actual data
  2. numColThreads <= colIdx < numPaddedColThreads: Zero both quantized output and SF
  3. colIdx >= numPaddedColThreads: Zero SF only (extra padding for swizzled SF layout)

Memory accesses are correctly bounded - outOffset is only used when colIdx < numPaddedColThreads.

@bkryu
Collaborator Author

bkryu commented Nov 3, 2025

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

829-861: Consider restructuring column handling for clarity and minor efficiency gain.

The current logic is correct but executes lines 829-837 and 840-844 for overlapping column ranges. Threads with colIdx in [numColThreads, numPaddedColThreads) zero both quantized output (lines 829-837) and SF output (lines 842-843), while threads in [numPaddedColThreads, numColThreadsForSf) only zero SF output.

Consider restructuring as non-overlapping branches:

-          // Set the values to 0 of those are padded columns.
-          if (colIdx >= numColThreads && colIdx < numPaddedColThreads) {
-            // Dispatch the quantization kernel.
-            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
-              reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
-                                 quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
-              reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-            }
-          }
-
-          // Process actual data or padding
-          if (colIdx >= numColThreads) {
-            // Column padding: Set the SF padding to 0.
+          if (colIdx >= numPaddedColThreads) {
+            // SF-only padding region: zero SF output only
             if (sf_out != nullptr) {
               sf_out[0] = 0x00;
             }
+          } else if (colIdx >= numColThreads) {
+            // Quantized output padding region: zero both quantized output and SF
+            if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
+              reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+            } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
+                                 quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
+              reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+            }
+            if (sf_out != nullptr) {
+              sf_out[0] = 0x00;
+            }
           } else {
-            // Load the input vector.
+            // Actual data region: load input and quantize
             PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
 
             // Dispatch the quantization kernel.

This makes the three column regions explicit and avoids redundant condition checks.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d060264 and 26355ca.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (2)

790-809: Padding row handling looks correct.

The fast path correctly skips quantized output writes for padding rows (which don't exist in the output tensor) and only zeros the scale factor buffer. The SF offset calculation at line 801 correctly uses numColsForSf / SF_VEC_SIZE, ensuring proper bounds for swizzled layouts.


810-837: Offset calculations and column padding zeroing are correct.

The data path properly computes input/output offsets using the appropriate column counts (numColThreads for input, numPaddedColThreads for output), and correctly zeros quantized output for column padding. The SF offset calculation at line 819 matches the padding path in using numColsForSf / SF_VEC_SIZE.

@flashinfer-bot
Collaborator

GitLab MR !106 has been updated with latest changes, and the CI pipeline #37823595 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu changed the title from "[wip] perf: Speed up fp4 quantization for small batch with swizzling for cutlass MoE" to "perf: Speed up fp4 quantization for small batch with swizzling for cutlass MoE" Nov 3, 2025
@flashinfer-bot
Collaborator

[FAILED] Pipeline #37823595: 12/17 passed

Collaborator

@yzh119 yzh119 left a comment


Impressive speedup and the separation of hot path and cold path looks reasonable to me, thanks for this effort!

The failed gb200 ut is not relevant.

@yzh119
Collaborator

yzh119 commented Nov 4, 2025

cc @djmmoss @yongwww @wenscarl for another look

@bkryu
Collaborator Author

bkryu commented Nov 4, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !106 has been updated with latest changes, and the CI pipeline #37898618 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Nov 4, 2025

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #37898618 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Nov 4, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !106 has been created, and the CI pipeline #37898689 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #37898689: 13/17 passed

@wenscarl wenscarl self-requested a review November 5, 2025 04:04
@yzh119 yzh119 merged commit 2580610 into flashinfer-ai:main Nov 5, 2025
4 checks passed
@yzh119 yzh119 mentioned this pull request Nov 5, 2025
@bkryu bkryu deleted the fp4_quantization_fix branch November 7, 2025 00:52
