
feat: Add CuTe-DSL backend for NVFP4 quantization#2838

Merged
bkryu merged 14 commits into flashinfer-ai:main from bkryu:cute_dsl_nvfp4_quant
Mar 26, 2026
Conversation

@bkryu
Collaborator

@bkryu bkryu commented Mar 20, 2026

📌 Description

  • Adds backend="cute-dsl" support to nvfp4_quantize with two kernel variants:
    • Default kernel: vectorized global loads (ld.global.v4.u32), optimal for small-to-medium problems
    • TMA kernel: producer-consumer warp specialization (1 producer + 8 consumer warps), 3D TMA with SWIZZLE_128B, optimal for large problems (M×K >= 2^25 elements)
  • Auto-dispatches between the two variants based on a log2(M) + log2(K) >= 25 threshold
  • Supports all SF layouts (128x4, 8x4, linear), fp16/bf16 input dtypes, and PDL
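The dispatch heuristic above can be sketched in a few lines of standalone Python. This is a hypothetical illustration — `pick_kernel` and `_TMA_LOG2_MK_THRESHOLD` mirror, but are not, the actual flashinfer internals:

```python
# Hypothetical sketch of the auto-dispatch heuristic described above.
# floor(log2(M)) + floor(log2(K)) >= 25 approximates M*K >= 2^25.
_TMA_LOG2_MK_THRESHOLD = 25

def pick_kernel(m: int, k: int) -> str:
    """Choose between the vectorized-load kernel and the TMA kernel."""
    log2_m = m.bit_length() - 1  # floor(log2(m)) for positive m
    log2_k = k.bit_length() - 1
    if log2_m + log2_k >= _TMA_LOG2_MK_THRESHOLD:
        return "tma"      # large problems: producer-consumer TMA kernel
    return "default"      # small-to-medium: vectorized global loads

print(pick_kernel(1024, 4096))   # 10 + 12 = 22 -> "default"
print(pick_kernel(32768, 8192))  # 15 + 13 = 28 -> "tma"
```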

Performance

  • 1.16x geometric mean speedup over the CUDA backend across 99 (M, K) configurations on B200
  • Faster in 85/99 cases, worst case 0.98x
  • Achieves up to 6.7 TB/s memory bandwidth (84% of B200 peak)

Performance Sweeps in M (rows) and K (cols) space comparing the two backends:
B200:

================================================================================
Summary: CuTe-DSL speedup over CUDA  (>1 = CuTe-DSL faster)
================================================================================
M\K          512    1024    2048    4096    6144    8192   12288   16384   32768
--------------------------------------------------------------------------------
128         1.10    1.13    1.14    1.22    1.43    1.43    1.34    1.52    1.55
256         1.12    1.14    1.19    1.24    1.41    1.38    1.31    1.43    1.43
512         1.18    1.17    1.27    1.39    1.45    1.40    1.28    1.35    1.31
1024        1.18    1.18    1.22    1.24    1.30    1.31    1.35    1.37    1.34
2048        1.16    1.17    1.21    1.17    1.20    1.15    1.17    1.16    1.11
4096        1.14    1.14    1.16    1.10    1.09    1.09    1.07    1.06    1.04
8192        1.19    1.14    1.14    1.10    1.08    1.08    1.06    1.05    1.02
16384       1.25    1.16    1.10    1.10    1.10    1.11    1.09    1.09    1.10
32768       1.25    1.06    1.02    1.02    1.01    1.02    1.00    0.99    1.00
65536       1.21    1.07    0.99    1.00    0.99    0.99    0.99    0.99    0.99
131072      1.37    1.09    0.99    0.98    0.98    0.98    0.98    1.01    0.99

Geometric mean: 1.15x
Min: 0.98x   Max: 1.55x
Cases where CuTe-DSL faster: 85/99

RTX PRO 6000 Workstation:

================================================================================
Summary: CuTe-DSL speedup over CUDA  (>1 = CuTe-DSL faster)
================================================================================
M\K          512    1024    2048    4096    6144    8192   12288   16384   32768
--------------------------------------------------------------------------------
128         1.10    1.05    0.92    0.95    1.52    1.15    1.10    1.19    1.11
256         1.04    0.94    1.03    1.06    1.27    1.27    1.05    1.07    1.06
512         0.98    0.95    0.98    1.01    1.12    1.10    1.09    1.03    1.06
1024        1.32    1.34    1.35    1.28    1.34    1.22    1.13    1.03    0.99
2048        1.33    1.33    1.16    1.08    1.14    1.06    1.01    1.00    1.01
4096        1.51    1.35    1.15    1.04    1.04    1.01    1.00    1.00    1.00
8192        1.41    1.16    1.04    1.02    1.00    1.01    1.00    1.00    1.00
16384       1.16    1.05    1.01    1.01    1.00    1.00    1.00    1.00    1.00
32768       1.03    1.02    1.01    1.00    1.00    1.00    1.00    1.00    1.00
65536       1.03    1.01    1.00    1.00    1.00    1.00    1.00    1.00    1.00
131072      1.01    1.01    1.00    1.00    1.01    1.00    1.00    1.00    0.99

Geometric mean: 1.07x
Min: 0.92x   Max: 1.52x
Cases where CuTe-DSL faster: 86/99

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added CuTe-DSL backend support for NVFP4/MXFP4 FP4 quantization with experimental backend selection.
    • Introduced configurable scale-factor layouts (128x4, 8x4, linear) and new layout-aware quantization paths.
    • Extended public quantization APIs with a backend parameter and CuTe-DSL execution paths.
  • Tests

    • Added comprehensive backend-parametrized tests exercising CUDA vs CuTe-DSL parity, NVFP4 layouts, FP8 inputs, and large-shape TMA scenarios.
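The `backend` parameter described above amounts to a string-keyed dispatch table. A minimal pure-Python sketch of that pattern follows — the function bodies here are placeholders, not the actual flashinfer implementations, which wrap compiled kernels and validate devices and dtypes:

```python
# Illustrative sketch of string-keyed backend dispatch; the quantize
# functions are stand-ins, not the real flashinfer kernels.
def _quantize_cuda(x):
    return ("cuda", x)

def _quantize_cute_dsl(x):
    return ("cute-dsl", x)

_BACKENDS = {"cuda": _quantize_cuda, "cute-dsl": _quantize_cute_dsl}

def nvfp4_quantize(x, backend="cuda"):
    """Dispatch to the requested backend, rejecting unknown names early."""
    try:
        impl = _BACKENDS[backend]
    except KeyError:
        raise ValueError(f"unsupported backend: {backend!r}") from None
    return impl(x)
```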

@coderabbitai
Contributor

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

This PR adds CuTe‑DSL as a second backend for NVFP4 quantization, implements new CuTe‑DSL NVFP4 kernels and FP4/CUTE helpers, extends fp4_quantize/nvfp4_quantize APIs with a backend parameter and dispatch, updates MXFP4 kernel layout handling, and expands tests and benchmark mappings for multi‑backend validation.

Changes

Cohort / File(s) Summary
Benchmarks & Tests
benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/quantization.py, tests/utils/test_fp4_quantize.py
Added cute-dsl to NVFP4 backend support, forwarded backend param to nvfp4_quantize, parametrized and extended FP4/NVFP4 tests to exercise cuda and cute-dsl.
CuTe‑DSL helpers
flashinfer/cute_dsl/fp4_common.py, flashinfer/quantization/quantization_cute_dsl_utils.py
Added multiple CuTe‑DSL user ops and FP4/E2M1 conversion/packing utilities, SF index helpers, and a rewritten float→ue8m0 fast path.
Quantization API & Dispatch
flashinfer/quantization/fp4_quantization.py, flashinfer/quantization/__init__.py
Added backend parameter to fp4_quantize/nvfp4_quantize, implemented CuTe‑DSL dispatch and validation, exposed nvfp4_quantize_cute_dsl conditionally.
NVFP4 CuTe‑DSL kernels (new)
flashinfer/quantization/kernels/nvfp4_quantize.py
New comprehensive NVFP4 CuTe‑DSL implementation: SF layout constants, swizzled and TMA kernels, compilation/caching helpers, entry point nvfp4_quantize_cute_dsl.
MXFP4 kernel generalization
flashinfer/quantization/kernels/mxfp4_quantize.py, flashinfer/quantization/kernels/__init__.py
Generalized MXFP4 kernel to configurable sf_layout, added layout constants, adjusted compilation cache keys and public exports.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Client
    participant FP4API as fp4_quantize API
    participant NVFPDispatch as nvfp4_quantize (dispatcher)
    participant CudaKernel as CUDA implementation
    participant CuteDslKernel as CuTe‑DSL implementation

    Test->>FP4API: fp4_quantize(input,..., backend="cuda"|"cute-dsl")
    FP4API->>NVFPDispatch: nvfp4_quantize(..., backend)
    alt backend == "cuda"
        NVFPDispatch->>CudaKernel: execute CUDA path
        CudaKernel-->>NVFPDispatch: (fp4_output, scale)
    else backend == "cute-dsl"
        NVFPDispatch->>NVFPDispatch: validate CuTe‑DSL, map sf_layout
        NVFPDispatch->>CuteDslKernel: select & invoke (Swizzled|TMA)
        CuteDslKernel-->>NVFPDispatch: (fp4_output, scale)
    end
    NVFPDispatch-->>FP4API: return (fp4_output, scale)
    FP4API-->>Test: result
sequenceDiagram
    participant Kernel as NVFP4QuantizeSwizzledKernel
    participant Scale as Scale computation
    participant Output as Output writer

    Kernel->>Kernel: Partition into 16‑elem blocks
    loop per block
        Kernel->>Scale: compute block max / derive E2M1 scale
        Scale->>Kernel: normalized scale, pack to E2M1
        Kernel->>Output: write packed FP4 bytes and scale factor at layout offset
    end
    Output-->>Kernel: quantized outputs
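The per-block flow in the diagram above can be modeled with a short reference sketch. This is a hedged approximation — the real kernels store an FP8 scale and use hardware rounding/packing; `E2M1_GRID` lists the representable FP4 E2M1 magnitudes:

```python
# Reference model of per-16-element-block FP4 quantization, roughly
# following the sequence diagram above. Not the actual kernel: scale
# handling and rounding are simplified for illustration.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0]  # FP4 E2M1 magnitudes

def quantize_block(block):
    """Quantize one 16-element block: derive a scale from the block max,
    then snap each normalized value to the nearest E2M1 magnitude."""
    assert len(block) == 16
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # 6.0 is the largest E2M1 value
    quantized = []
    for v in block:
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        quantized.append(mag if v >= 0 else -mag)
    return quantized, scale  # dequantize as q * scale
```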

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related issues

Possibly related PRs

Suggested labels

cute-dsl, benchmark, ready

Suggested reviewers

  • aleozlx
  • nvmbreughe
  • cyx-6
  • jimmyzho
  • yzh119
  • nv-yunzheq
  • kahyunnam

Poem

🐇 I hopped from CUDA's tidy trail to try,
New CuTe‑DSL fields beneath the sky,
Pack E2M1, shuffle scales with flair,
Two backends now, both doing their share,
A tiny rabbit’s cheer — quantize, hop, and fly! 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.93% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely summarizes the main change: adding CuTe-DSL backend support for NVFP4 quantization, which aligns with the primary objective of the PR.
Description check ✅ Passed The PR description includes a detailed summary of changes, performance results with sweep tables, and confirms completion of pre-commit checks and testing requirements.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized CuTe-DSL backend for NVFP4 quantization, aiming to boost performance and expand functionality. It incorporates advanced GPU programming techniques like TMA-based memory access and intelligent kernel dispatching to achieve substantial speedups across a wide range of problem sizes. The changes also enhance the flexibility of existing MXFP4 kernels and ensure robust operation through comprehensive testing.

Highlights

  • CuTe-DSL Backend for NVFP4 Quantization: Introduced a new CuTe-DSL backend for NVFP4 quantization, offering significant performance improvements over the existing CUDA backend, with a 1.16x geometric mean speedup across various configurations on B200.
  • Dual Kernel Variants and Auto-Dispatch: Implemented two kernel variants: a default kernel using vectorized global loads optimized for small-to-medium problems, and a TMA (Tensor Memory Accelerator) kernel with producer-consumer warp specialization for large problems. The system automatically dispatches between these variants based on problem size (log2(M) + log2(K) >= 25 threshold).
  • Expanded Layout and Dtype Support: The new CuTe-DSL backend supports all standard SF (Scale Factor) layouts (128x4, 8x4, linear), various input dtypes including fp16, bf16, and fp8 (float8_e4m3fn), and PDL (Programmatic Dependent Launch).
  • Improved MXFP4 Kernel Flexibility: The existing MXFP4 CuTe-DSL kernel was refactored to support multiple scale factor layouts (swizzled 128x4 and linear), enhancing its versatility.
  • Comprehensive Testing: Added extensive unit tests to verify the correctness and parity of the CuTe-DSL backend against the CUDA backend for NVFP4 quantization, including roundtrip tests, backend parity checks, FP8 input tests, and specific tests for the TMA kernel.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new, high-performance CuTe-DSL backend for NVFP4 quantization, complete with two kernel variants (default and TMA-based) and automatic dispatching. The implementation is comprehensive, adding new low-level PTX wrappers, new kernels, and refactoring existing ones for better generality. The performance gains are substantial. The accompanying tests are thorough, ensuring correctness and parity with the existing CUDA backend. I've identified a potential bug in the scale factor layout handling when shuffling is enabled for the new backend, along with an opportunity for refactoring to reduce code duplication.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (3)
tests/utils/test_fp4_quantize.py (2)

136-138: Move _is_cute_dsl_available() definition before first usage.

The helper function _is_cute_dsl_available() is called here but defined later at line 353. While Python resolves this at runtime, it harms readability. Consider moving the definition (lines 353-360) before its first usage, perhaps near the other helper functions like _is_fp4_supported() at line 27.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_fp4_quantize.py` around lines 136 - 138, Move the helper
function _is_cute_dsl_available() so it appears before its first call in
tests/utils/test_fp4_quantize.py; specifically, relocate the function definition
(currently at lines ~353-360) up into the helper section near
_is_fp4_supported() (around line ~27) so that the call inside the backend check
(if backend == "cute-dsl": if not _is_cute_dsl_available(): pytest.skip(...))
references a previously defined function, preserving existing name/signature and
any imports used by _is_cute_dsl_available().

666-672: Same exact-equality concern as MXFP4 parity test.

Consider aligning with a small tolerance if flaky failures occur, or document why bit-exact parity is required here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_fp4_quantize.py` around lines 666 - 672, The test uses
torch.testing.assert_close with rtol=0 and atol=0 comparing dq_cuda and dq_cute
which enforces bit-exact equality and may cause flaky failures; update the
assertion in tests/utils/test_fp4_quantize.py to allow a small numeric tolerance
(e.g., set a non-zero rtol and/or atol) or add a clear comment documenting why
bit-exact parity is required; specifically modify the assertion call for
torch.testing.assert_close(dq_cuda, dq_cute, rtol=..., atol=..., msg=error_msg)
to use appropriate tolerances or add an explanatory comment near
dq_cuda/dq_cute/error_msg explaining the rationale.
flashinfer/quantization/fp4_quantization.py (1)

742-750: Parameter name input shadows Python builtin.

The parameter input shadows the Python builtin function. While functional, this can cause subtle issues if the builtin is needed within the function.

Suggested rename
 def _fp4_quantize_cute_dsl(
-    input: torch.Tensor,
+    x: torch.Tensor,
     global_scale: Optional[torch.Tensor],
     sf_vec_size: int,
     sf_use_ue8m0: bool,
     is_sf_swizzled_layout: bool,
     is_sf_8x4_layout: bool,
     enable_pdl: Optional[bool],
 ) -> Tuple[torch.Tensor, torch.Tensor]:

(And update all references to input within the function to x)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 742 - 750, The
parameter name `input` in function _fp4_quantize_cute_dsl shadows the Python
builtin; rename the parameter to `x` (update the signature type hint from
`input: torch.Tensor` to `x: torch.Tensor`) and update every reference inside
the function body (all reads/writes, any slices, clones, and variable
assignments that use `input`) to use `x` instead so the builtin is not shadowed
and behavior remains identical.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 991-1002: The CuTe-DSL branch calls nvfp4_quantize_cute_dsl with
tensors that may still be on CPU (a, a_global_sf), causing failures; before
calling nvfp4_quantize_cute_dsl ensure the inputs are moved to CUDA (e.g., call
a = a.cuda() and a_global_sf = a_global_sf.cuda() or otherwise validate device)
so device placement matches the CUDA path, and keep the subsequent do_shuffle
steps (shuffle_matrix_a, shuffle_matrix_sf_a) operating on CUDA tensors or move
them back as needed.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 448-449: The code currently accepts any sf_layout value and
proceeds down the "swizzled padding" branch but only handles SF_LAYOUT_128x4
when computing offsets; add an explicit validation of the sf_layout parameter
(e.g., allow only SF_LAYOUT_LINEAR and SF_LAYOUT_128x4) at the start of the
routine that computes output sizing/padding so unknown values raise an error
immediately rather than allocating a mismatched buffer; update the same
validation in the other sizing branch referenced around lines 503-514 so both
sizing paths reject invalid sf_layout values before allocating or writing
buffers.

In `@flashinfer/quantization/kernels/nvfp4_quantize.py`:
- Around line 1158-1162: scale_output is being reshaped using logical column
count (num_sf_blocks_per_row) even when the buffer is allocated in a
swizzled/padded physical layout (padded_m, padded_sf_cols); this causes errors
and wrong ordering. Fix by detecting when the buffer is in physical/swizzled
layout (sf_layout != SF_LAYOUT_LINEAR or padded_sf_cols !=
num_sf_blocks_per_row) and either (a) perform an explicit unswizzle + remove
padding to produce a contiguous logical buffer sized m * num_sf_blocks_per_row
before calling scale_output.reshape(-1, num_sf_blocks_per_row), or (b) if you
intend to expose the physical buffer, return it with its physical dimensions
(padded_m, padded_sf_cols) instead of reshaping; apply the same change to the
other occurrence around lines 1195-1199 so all return paths handle
swizzling/padding consistently (use the symbols scale_output, sf_layout,
SF_LAYOUT_LINEAR, m, num_sf_blocks_per_row, padded_m, padded_sf_cols to locate
the code).
- Around line 1097-1104: The code handling global_scale in nvfp4_quantize
creates global_scale_tensor but skips moving CUDA scalars from other devices to
input.device; ensure global_scale_tensor is always materialized on input.device
by converting to float, reshaping/contiguous, and then unconditionally calling
.to(input.device) (or .to(device) with appropriate dtype) when global_scale is a
torch.Tensor so that the tensor used by the kernel (global_scale_tensor) is
guaranteed to live on input.device and avoids cross-device launch failures;
update the logic around global_scale/global_scale_tensor to always perform the
device transfer before passing into the kernel.
- Around line 967-976: The _should_use_tma predicate currently uses
floor(log2(m)) + floor(log2(k)) via bit_length to decide the crossover; change
it to the documented M*K cutoff by replacing the final return with a check that
m * k >= 1 << _TMA_LOG2_MK_THRESHOLD so the TMA kernel is dispatched exactly
when M*K meets the threshold; keep the existing dtype check for
torch.float8_e4m3fn and the early returns for k % _TMA_COLS_PER_STAGE and m <
_TMA_MIN_M intact and update the return statement in _should_use_tma
accordingly.

In `@flashinfer/quantization/quantization_cute_dsl_utils.py`:
- Around line 141-175: The float_to_ue8m0_fast path treats subnormals as having
a bump (exp_biased==0 && mantissa!=0) which yields 1; change the ASM to detect
subnormals and suppress that bump so subnormals become zero: after computing
exp_biased and mantissa, set a predicate for subnormal (exp_biased == 0) and
combine it with p_has_mant to form p_subnormal_has_mant, then only use a bump
when mantissa != 0 AND NOT subnormal (i.e., replace the current selp for bump
that uses p_has_mant with one that uses p_has_mant && !p_subnormal); ensure
subsequent result/clamp logic uses that adjusted bump so true IEEE subnormals
map to 0. Reference: float_to_ue8m0_fast, variables exp_biased, mantissa,
p_has_mant, bump, result.
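For reference, the subnormal behavior requested above can be modeled in plain Python. This is a hypothetical bit-level model of float→ue8m0 (round the exponent up when the mantissa is nonzero, map subnormals to 0), not the actual PTX fast path:

```python
import struct

def float_to_ue8m0(x: float) -> int:
    """Bit-level model: ue8m0 keeps only a biased power-of-two exponent.
    Rounds up to the next power of two when the mantissa is nonzero;
    zero and IEEE subnormals (biased exponent 0) map to 0."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp_biased = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    if exp_biased == 0:
        return 0  # zero and subnormals: no exponent bump
    bump = 1 if mantissa != 0 else 0  # round up for non-powers-of-two
    return min(exp_biased + bump, 255)

print(float_to_ue8m0(1.0))  # exact power of two -> biased exponent 127
print(float_to_ue8m0(1.5))  # nonzero mantissa -> bumped to 128
```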


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 150c0967-33d6-405d-a654-5bb968d4a14a

📥 Commits

Reviewing files that changed from the base of the PR and between 7cb016d and de98b01.

📒 Files selected for processing (10)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/quantization.py
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/quantization/__init__.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/quantization/kernels/__init__.py
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • flashinfer/quantization/quantization_cute_dsl_utils.py
  • tests/utils/test_fp4_quantize.py

Comment on lines +967 to +976
def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool:
    """Determine if TMA kernel should be used based on problem dimensions."""
    if dtype == torch.float8_e4m3fn:
        return False
    if k % _TMA_COLS_PER_STAGE != 0:
        return False
    if m < _TMA_MIN_M:
        return False
    # Use log2(M) + log2(K) threshold for the crossover point
    return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD
Contributor

@coderabbitai coderabbitai bot Mar 20, 2026


⚠️ Potential issue | 🟠 Major

Use the documented M*K cutoff for TMA dispatch.

This predicate uses floor(log2(M)) + floor(log2(K)), so many rectangular cases above the intended 2^25 threshold still fall back to the vector-load kernel. m * k >= 1 << _TMA_LOG2_MK_THRESHOLD matches the stated dispatch rule exactly.

Suggested fix
 def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool:
     """Determine if TMA kernel should be used based on problem dimensions."""
     if dtype == torch.float8_e4m3fn:
         return False
     if k % _TMA_COLS_PER_STAGE != 0:
         return False
     if m < _TMA_MIN_M:
         return False
-    # Use log2(M) + log2(K) threshold for the crossover point
-    return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD
+    return m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)
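The gap between the two predicates is easy to demonstrate with a standalone sketch (`_THRESHOLD` stands in for `_TMA_LOG2_MK_THRESHOLD`; both rules are as discussed in this comment thread):

```python
# Standalone comparison of the two dispatch predicates discussed above.
_THRESHOLD = 25  # stands in for _TMA_LOG2_MK_THRESHOLD

def floor_log2_rule(m: int, k: int) -> bool:
    # floor(log2(M)) + floor(log2(K)) >= 25, as implemented in the PR
    return (m.bit_length() - 1) + (k.bit_length() - 1) >= _THRESHOLD

def product_rule(m: int, k: int) -> bool:
    # M * K >= 2^25, as documented
    return m * k >= (1 << _THRESHOLD)

# M = K = 6144: M*K = 37748736 >= 2^25 = 33554432, yet
# floor(log2(6144)) + floor(log2(6144)) = 12 + 12 = 24 < 25,
# so only the product rule would dispatch the TMA kernel here.
print(floor_log2_rule(6144, 6144), product_rule(6144, 6144))  # False True
```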
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/nvfp4_quantize.py` around lines 967 - 976,
The _should_use_tma predicate currently uses floor(log2(m)) + floor(log2(k)) via
bit_length to decide the crossover; change it to the documented M*K cutoff by
replacing the final return with a check that m * k >= 1 <<
_TMA_LOG2_MK_THRESHOLD so the TMA kernel is dispatched exactly when M*K meets
the threshold; keep the existing dtype check for torch.float8_e4m3fn and the
early returns for k % _TMA_COLS_PER_STAGE and m < _TMA_MIN_M intact and update
the return statement in _should_use_tma accordingly.
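As an aside, the divergence between the two predicates is easy to demonstrate in isolation. The sketch below is a standalone illustration (the threshold constant and function names are assumed from this discussion, nothing is imported from FlashInfer): the floor-of-log2 sum agrees with the exact M×K cutoff on power-of-two shapes but can disagree on rectangular ones.

```python
# Standalone sketch contrasting the two dispatch predicates discussed above.
# _TMA_LOG2_MK_THRESHOLD = 25 is assumed from the PR discussion.
_TMA_LOG2_MK_THRESHOLD = 25

def approx_predicate(m: int, k: int) -> bool:
    # floor(log2(m)) + floor(log2(k)) >= 25, as implemented in the PR
    return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD

def exact_predicate(m: int, k: int) -> bool:
    # literal M*K >= 2^25 cutoff, as the docstring suggested
    return m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)

# Power-of-two shapes: both predicates agree (4096 * 8192 == 2^25).
print(approx_predicate(4096, 8192), exact_predicate(4096, 8192))  # True True

# Rectangular shape just above the cutoff: 6144 * 6144 ~= 1.12 * 2^25,
# but floor(log2(6144)) = 12 on each side, so 12 + 12 = 24 < 25.
print(approx_predicate(6144, 6144), exact_predicate(6144, 6144))  # False True
```

Shapes up to roughly 2x above the product cutoff can therefore still dispatch to the vector-load kernel under the approximation.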

Collaborator Author

The documentation should be updated to be more clear in this case. Will do so in next commit

Contributor

@bkryu, understood — a clarifying comment that the floor(log₂(M)) + floor(log₂(K)) predicate is an intentional approximation of the M×K ≥ 2²⁵ threshold (not an accidental one) would make the dispatch logic self-documenting. Thanks for the follow-up!


✏️ Learnings added
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2838
File: flashinfer/quantization/kernels/nvfp4_quantize.py:967-976
Timestamp: 2026-03-23T18:58:19.372Z
Learning: In `flashinfer/quantization/kernels/nvfp4_quantize.py` (flashinfer-ai/flashinfer), the TMA dispatch predicate `m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD` (i.e., floor(log2(M)) + floor(log2(K)) >= 25) is intentional. It is a deliberate approximation of the `M*K >= 2^25` threshold — not a bug. The maintainer acknowledged this and will add a clarifying comment in a follow-up commit. Do not flag this as incorrect or suggest replacing it with `m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)`.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2865
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh:81-81
Timestamp: 2026-03-23T18:17:03.882Z
Learning: In flashinfer-ai/flashinfer selective_state_update (csrc/selective_state_update.cu), the public API intentionally broadcasts:
- D over dim (D.stride(1)==0),
- A over dim and dstate (A.stride(1)==0 and A.stride(2)==0),
- dt_bias over dim (bias.stride(1)==0),
- dt over dim in both STP (dt.stride(2)==0) and MTP (dt.stride(3)==0, or dt.stride(2)==0 in varlen).
Horizontal MTP kernel (include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh) correctly loads A/D/dt_bias per head and dt per (head, step) and stores dt in smem as [HEADS_PER_CTA][TOKENS_MTP].

Learnt from: xrq-phys
Repo: flashinfer-ai/flashinfer PR: 2711
File: csrc/trtllm_fmha_kernel_launcher.cu:552-563
Timestamp: 2026-03-07T06:34:53.719Z
Learning: In `csrc/trtllm_fmha_kernel_launcher.cu` (flashinfer-ai/flashinfer), dtype validation for SageAttention scaling-factor tensors (`sage_attn_sfs_q/k/p/v`) is intentionally absent. This file is a TVM FFI path (not a PyTorch extension path), and dtype validation is expected to be handled at a different layer/entry point. Do not flag missing `TVM_FFI_ICHECK_EQ(...dtype(), dl_float32)` checks for these tensors in this file.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2773
File: include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh:27-32
Timestamp: 2026-03-12T21:29:16.342Z
Learning: In `include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` (flashinfer-ai/flashinfer), the `static_assert` inside the `PHILOX_ROUNDS > 0` block that restricts stochastic rounding to fp16 state (`std::is_same_v<state_t, half>`) is intentionally kept in the CUDA header close to the implementation rather than being guarded by a pre-JIT Python-side runtime check. The maintainer prefers this colocation for easier auditability. Do not suggest moving or duplicating this constraint to the Python layer.

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-03-04T05:20:26.963Z
Learning: Keep documentation in sync with code changes, particularly CLAUDE.md and `.claude/skills/` when modifying infrastructure changes, patterns, new conventions, or deprecations

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2635
File: benchmarks/routines/moe.py:547-551
Timestamp: 2026-02-25T00:56:38.933Z
Learning: In FlashInfer's quantization code, `torch.float8_e4m3fn` is used as a "carrier dtype" for 1-byte scale factors (UE8M0, etc.) because PyTorch lacks native support for these formats. The pattern `scale_uint8.view(torch.float8_e4m3fn)` is standard throughout the codebase - it's not a semantic conversion but a dtype relabeling, and the C++ kernels interpret the raw bytes correctly regardless of the PyTorch dtype label. This applies to both FP4 and MXFP8 quantization scales.

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-03-04T05:20:26.963Z
Learning: Applies to include/flashinfer/**/*.cuh : Write kernel implementations in `include/flashinfer/` using framework-agnostic CUDA code that accepts raw pointers

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/quantization/fp4_quantization.py (1)

776-778: Consider adding .cuda() for consistency with nvfp4_quantize.

The _fp4_quantize_cute_dsl helper passes input and global_scale directly to the CuTe-DSL kernels without ensuring they're on CUDA. While the underlying nvfp4_quantize_cute_dsl has an assert input.is_cuda check, adding explicit .cuda() calls here would:

  1. Match the pattern in nvfp4_quantize (lines 1003-1004)
  2. Provide clearer error messages for CPU tensor inputs
Suggested fix
-        return nvfp4_quantize_cute_dsl(
-            input, global_scale, sf_layout=sf_layout, enable_pdl=enable_pdl
+        return nvfp4_quantize_cute_dsl(
+            input.cuda(), global_scale.cuda() if global_scale is not None else None,
+            sf_layout=sf_layout, enable_pdl=enable_pdl
         )

Similar change needed for the MXFP4 path at line 794-796.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 776 - 778, The
helper _fp4_quantize_cute_dsl currently forwards input and global_scale to
nvfp4_quantize_cute_dsl without ensuring CUDA tensors; update
_fp4_quantize_cute_dsl to call .cuda() on both input and global_scale before
passing them into nvfp4_quantize_cute_dsl (mirroring the pattern in
nvfp4_quantize) and apply the same change for the MXFP4 path helper that calls
nvfp4_quantize_cute_dsl so CPU tensors are moved to CUDA and produce clearer
errors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bf77357d-39fa-48d9-badd-a2921872a820

📥 Commits

Reviewing files that changed from the base of the PR and between de98b01 and 6c9d24e.

📒 Files selected for processing (5)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/quantization.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/quantization.py

Comment on lines 533 to 536
kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)

# Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row]
scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)

Contributor

⚠️ Potential issue | 🟡 Minor

Potential reshape mismatch for swizzled layouts.

For swizzled layouts, scale_output is allocated with padded_m * padded_sf_cols elements where padded_sf_cols may be larger than num_sf_blocks_per_row (padded to multiple of 4). The reshape uses num_sf_blocks_per_row as the trailing dimension, which could fail or produce incorrect results when num_sf_blocks_per_row is not a multiple of 4.

Consider using padded_sf_cols for the reshape dimension (consistent with how NVFP4 handles this):

-    scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)
+    scale_output = scale_output.reshape(-1, padded_sf_cols)

Alternatively, if the intent is to expose logical dimensions, trim the swizzled buffer first or document that callers must interpret the physical layout.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 533 - 536,
The reshape of scale_output after kernel_fn can mismatch for swizzled layouts
because scale_output is allocated with padded_m * padded_sf_cols elements but is
reshaped using num_sf_blocks_per_row; update the reshape to use padded_sf_cols
(keeping physical layout) or explicitly trim the swizzled buffer to padded_m *
num_sf_blocks_per_row before reshaping so the trailing dim equals
num_sf_blocks_per_row; locate usage around kernel_fn, scale_output,
num_sf_blocks_per_row, padded_sf_cols and padded_m and apply the appropriate
change (use padded_sf_cols in reshape or slice scale_output to remove padding
first).
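For illustration, the shape mismatch the reviewer describes can be reproduced in isolation with NumPy. The buffer sizes below are hypothetical, chosen only so that the padded column count differs from the logical one:

```python
import numpy as np

# Hypothetical sizes: a swizzled scale buffer padded to a multiple of 4 columns.
padded_m = 128
num_sf_blocks_per_row = 6   # logical SF columns
padded_sf_cols = 8          # physical columns after padding to a multiple of 4

scale_output = np.zeros(padded_m * padded_sf_cols, dtype=np.uint8)

# Reshaping with the physical (padded) column count always works:
physical = scale_output.reshape(-1, padded_sf_cols)
print(physical.shape)  # (128, 8)

# Reshaping with the logical column count fails whenever the padded element
# count is not divisible by it (128 * 8 = 1024 is not a multiple of 6):
try:
    scale_output.reshape(-1, num_sf_blocks_per_row)
except ValueError as e:
    print("reshape failed:", e)
```

Even when the element count happens to divide evenly, the rows of the logical reshape would no longer line up with the swizzled physical rows, so trimming the padding first (or keeping padded_sf_cols) is the safer fix.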

@bkryu
Collaborator Author

bkryu commented Mar 23, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !447 has been created, and the CI pipeline #46810194 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46810194: 13/20 passed

@bkryu bkryu self-assigned this Mar 24, 2026
Collaborator

@kahyunnam kahyunnam left a comment


LGTM

@bkryu bkryu enabled auto-merge (squash) March 24, 2026 23:26
@bkryu bkryu requested a review from samuellees as a code owner March 25, 2026 05:08
@bkryu bkryu merged commit d426b18 into flashinfer-ai:main Mar 26, 2026
29 checks passed
