feat: Add CuTe-DSL backend for NVFP4 quantization #2838
bkryu merged 14 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR adds CuTe‑DSL as a second backend for NVFP4 quantization, implements new CuTe‑DSL NVFP4 kernels and FP4/CUTE helpers, extends the fp4_quantize/nvfp4_quantize APIs with a backend parameter and dispatch, updates MXFP4 kernel layout handling, and expands tests and benchmark mappings for multi‑backend validation.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Test as Test/Client
participant FP4API as fp4_quantize API
participant NVFPDispatch as nvfp4_quantize (dispatcher)
participant CudaKernel as CUDA implementation
participant CuteDslKernel as CuTe‑DSL implementation
Test->>FP4API: fp4_quantize(input,..., backend="cuda"|"cute-dsl")
FP4API->>NVFPDispatch: nvfp4_quantize(..., backend)
alt backend == "cuda"
NVFPDispatch->>CudaKernel: execute CUDA path
CudaKernel-->>NVFPDispatch: (fp4_output, scale)
else backend == "cute-dsl"
NVFPDispatch->>NVFPDispatch: validate CuTe‑DSL, map sf_layout
NVFPDispatch->>CuteDslKernel: select & invoke (Swizzled|TMA)
CuteDslKernel-->>NVFPDispatch: (fp4_output, scale)
end
NVFPDispatch-->>FP4API: return (fp4_output, scale)
FP4API-->>Test: result
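A minimal Python sketch of the dispatch flow in the diagram above (the helper names and placeholder bodies here are illustrative, not the PR's exact internals):

```python
def _nvfp4_quantize_cuda(x, global_scale):
    # Placeholder for the CUDA kernel path.
    return ("cuda", x, global_scale)

def _nvfp4_quantize_cute_dsl(x, global_scale):
    # Placeholder for the CuTe-DSL kernel path (Swizzled or TMA variant).
    return ("cute-dsl", x, global_scale)

def nvfp4_quantize(x, global_scale=None, backend="cuda"):
    """Route to the requested backend, mirroring the diagram above."""
    if backend == "cuda":
        return _nvfp4_quantize_cuda(x, global_scale)
    if backend == "cute-dsl":
        return _nvfp4_quantize_cute_dsl(x, global_scale)
    raise ValueError(f"unsupported backend: {backend!r}")
```

Both branches return the same `(fp4_output, scale)` shape to the caller, which is what lets `fp4_quantize` treat the backend as an opaque choice.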
sequenceDiagram
participant Kernel as NVFP4QuantizeSwizzledKernel
participant Scale as Scale computation
participant Output as Output writer
Kernel->>Kernel: Partition into 16‑elem blocks
loop per block
Kernel->>Scale: compute block max / derive E2M1 scale
Scale->>Kernel: normalized scale, pack to E2M1
Kernel->>Output: write packed FP4 bytes and scale factor at layout offset
end
Output-->>Kernel: quantized outputs
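The per-block loop in the second diagram can be modeled in plain Python. The E2M1 magnitude table below is the standard FP4 value set; the scale derivation and the omission of byte packing and UE8M0/FP8 scale encoding are simplifications for illustration:

```python
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable FP4 magnitudes

def quantize_block(block):
    """Sketch of the per-16-element flow: find the block max, derive a scale
    so values fit E2M1's range, quantize each element to the nearest
    representable magnitude. Packing two 4-bit codes per byte is omitted."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # 6.0 = largest E2M1 magnitude
    codes = []
    for v in block:
        mag = abs(v) / scale
        idx = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
        sign = 0x8 if v < 0 else 0x0
        codes.append(sign | idx)  # 4-bit E2M1 code: sign bit + value index
    return codes, scale
```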
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a highly optimized CuTe-DSL backend for NVFP4 quantization, aiming to boost performance and expand functionality. It incorporates advanced GPU programming techniques like TMA-based memory access and intelligent kernel dispatching to achieve substantial speedups across a wide range of problem sizes. The changes also enhance the flexibility of existing MXFP4 kernels and ensure robust operation through comprehensive testing.

Highlights
Code Review
This pull request introduces a new, high-performance CuTe-DSL backend for NVFP4 quantization, complete with two kernel variants (default and TMA-based) and automatic dispatching. The implementation is comprehensive, adding new low-level PTX wrappers, new kernels, and refactoring existing ones for better generality. The performance gains are substantial. The accompanying tests are thorough, ensuring correctness and parity with the existing CUDA backend. I've identified a potential bug in the scale factor layout handling when shuffling is enabled for the new backend, along with an opportunity for refactoring to reduce code duplication.
Actionable comments posted: 6
🧹 Nitpick comments (3)
tests/utils/test_fp4_quantize.py (2)
136-138: Move `_is_cute_dsl_available()` definition before first usage.

The helper function `_is_cute_dsl_available()` is called here but defined later at line 353. While Python resolves this at runtime, it harms readability. Consider moving the definition (lines 353-360) before its first usage, perhaps near the other helper functions like `_is_fp4_supported()` at line 27.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_fp4_quantize.py` around lines 136 - 138, Move the helper function _is_cute_dsl_available() so it appears before its first call in tests/utils/test_fp4_quantize.py; specifically, relocate the function definition (currently at lines ~353-360) up into the helper section near _is_fp4_supported() (around line ~27) so that the call inside the backend check (if backend == "cute-dsl": if not _is_cute_dsl_available(): pytest.skip(...)) references a previously defined function, preserving existing name/signature and any imports used by _is_cute_dsl_available().
666-672: Same exact-equality concern as MXFP4 parity test.

Consider aligning with a small tolerance if flaky failures occur, or document why bit-exact parity is required here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_fp4_quantize.py` around lines 666 - 672, The test uses torch.testing.assert_close with rtol=0 and atol=0 comparing dq_cuda and dq_cute which enforces bit-exact equality and may cause flaky failures; update the assertion in tests/utils/test_fp4_quantize.py to allow a small numeric tolerance (e.g., set a non-zero rtol and/or atol) or add a clear comment documenting why bit-exact parity is required; specifically modify the assertion call for torch.testing.assert_close(dq_cuda, dq_cute, rtol=..., atol=..., msg=error_msg) to use appropriate tolerances or add an explanatory comment near dq_cuda/dq_cute/error_msg explaining the rationale.

flashinfer/quantization/fp4_quantization.py (1)
742-750: Parameter name `input` shadows Python builtin.

The parameter `input` shadows the Python builtin function. While functional, this can cause subtle issues if the builtin is needed within the function.

Suggested rename

def _fp4_quantize_cute_dsl(
-    input: torch.Tensor,
+    x: torch.Tensor,
    global_scale: Optional[torch.Tensor],
    sf_vec_size: int,
    sf_use_ue8m0: bool,
    is_sf_swizzled_layout: bool,
    is_sf_8x4_layout: bool,
    enable_pdl: Optional[bool],
) -> Tuple[torch.Tensor, torch.Tensor]:

(And update all references to `input` within the function to `x`.)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/fp4_quantization.py` around lines 742 - 750, The parameter name `input` in function _fp4_quantize_cute_dsl shadows the Python builtin; rename the parameter to `x` (update the signature type hint from `input: torch.Tensor` to `x: torch.Tensor`) and update every reference inside the function body (all reads/writes, any slices, clones, and variable assignments that use `input`) to use `x` instead so the builtin is not shadowed and behavior remains identical.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 991-1002: The CuTe-DSL branch calls nvfp4_quantize_cute_dsl with
tensors that may still be on CPU (a, a_global_sf), causing failures; before
calling nvfp4_quantize_cute_dsl ensure the inputs are moved to CUDA (e.g., call
a = a.cuda() and a_global_sf = a_global_sf.cuda() or otherwise validate device)
so device placement matches the CUDA path, and keep the subsequent do_shuffle
steps (shuffle_matrix_a, shuffle_matrix_sf_a) operating on CUDA tensors or move
them back as needed.
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 448-449: The code currently accepts any sf_layout value and
proceeds down the "swizzled padding" branch but only handles SF_LAYOUT_128x4
when computing offsets; add an explicit validation of the sf_layout parameter
(e.g., allow only SF_LAYOUT_LINEAR and SF_LAYOUT_128x4) at the start of the
routine that computes output sizing/padding so unknown values raise an error
immediately rather than allocating a mismatched buffer; update the same
validation in the other sizing branch referenced around lines 503-514 so both
sizing paths reject invalid sf_layout values before allocating or writing
buffers.
In `@flashinfer/quantization/kernels/nvfp4_quantize.py`:
- Around line 1158-1162: scale_output is being reshaped using logical column
count (num_sf_blocks_per_row) even when the buffer is allocated in a
swizzled/padded physical layout (padded_m, padded_sf_cols); this causes errors
and wrong ordering. Fix by detecting when the buffer is in physical/swizzled
layout (sf_layout != SF_LAYOUT_LINEAR or padded_sf_cols !=
num_sf_blocks_per_row) and either (a) perform an explicit unswizzle + remove
padding to produce a contiguous logical buffer sized m * num_sf_blocks_per_row
before calling scale_output.reshape(-1, num_sf_blocks_per_row), or (b) if you
intend to expose the physical buffer, return it with its physical dimensions
(padded_m, padded_sf_cols) instead of reshaping; apply the same change to the
other occurrence around lines 1195-1199 so all return paths handle
swizzling/padding consistently (use the symbols scale_output, sf_layout,
SF_LAYOUT_LINEAR, m, num_sf_blocks_per_row, padded_m, padded_sf_cols to locate
the code).
- Around line 1097-1104: The code handling global_scale in nvfp4_quantize
creates global_scale_tensor but skips moving CUDA scalars from other devices to
input.device; ensure global_scale_tensor is always materialized on input.device
by converting to float, reshaping/contiguous, and then unconditionally calling
.to(input.device) (or .to(device) with appropriate dtype) when global_scale is a
torch.Tensor so that the tensor used by the kernel (global_scale_tensor) is
guaranteed to live on input.device and avoids cross-device launch failures;
update the logic around global_scale/global_scale_tensor to always perform the
device transfer before passing into the kernel.
- Around line 967-976: The _should_use_tma predicate currently uses
floor(log2(m)) + floor(log2(k)) via bit_length to decide the crossover; change
it to the documented M*K cutoff by replacing the final return with a check that
m * k >= 1 << _TMA_LOG2_MK_THRESHOLD so the TMA kernel is dispatched exactly
when M*K meets the threshold; keep the existing dtype check for
torch.float8_e4m3fn and the early returns for k % _TMA_COLS_PER_STAGE and m <
_TMA_MIN_M intact and update the return statement in _should_use_tma
accordingly.
In `@flashinfer/quantization/quantization_cute_dsl_utils.py`:
- Around line 141-175: The float_to_ue8m0_fast path treats subnormals as having
a bump (exp_biased==0 && mantissa!=0) which yields 1; change the ASM to detect
subnormals and suppress that bump so subnormals become zero: after computing
exp_biased and mantissa, set a predicate for subnormal (exp_biased == 0) and
combine it with p_has_mant to form p_subnormal_has_mant, then only use a bump
when mantissa != 0 AND NOT subnormal (i.e., replace the current selp for bump
that uses p_has_mant with one that uses p_has_mant && !p_subnormal); ensure
subsequent result/clamp logic uses that adjusted bump so true IEEE subnormals
map to 0. Reference: float_to_ue8m0_fast, variables exp_biased, mantissa,
p_has_mant, bump, result.
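A pure-Python reference model of the intended semantics (a sketch of the behavior the fixed PTX should have, not the PTX itself): round a positive float up to the nearest power of two and return the biased exponent, with IEEE subnormals mapping to 0 rather than getting the mantissa bump:

```python
import struct

def float_to_ue8m0(x: float) -> int:
    """Reference UE8M0 conversion: ceil a positive float32 to the next
    power of two and return its biased exponent (0..255). Subnormal inputs
    (biased exponent 0) map to 0 regardless of mantissa bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp_biased = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    is_subnormal = exp_biased == 0
    # Bump the exponent only for normal numbers with a nonzero mantissa;
    # this suppresses the bump that previously made subnormals yield 1.
    bump = 1 if (mantissa != 0 and not is_subnormal) else 0
    return min(exp_biased + bump, 255)
```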
---
Nitpick comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 742-750: The parameter name `input` in function
_fp4_quantize_cute_dsl shadows the Python builtin; rename the parameter to `x`
(update the signature type hint from `input: torch.Tensor` to `x: torch.Tensor`)
and update every reference inside the function body (all reads/writes, any
slices, clones, and variable assignments that use `input`) to use `x` instead so
the builtin is not shadowed and behavior remains identical.
In `@tests/utils/test_fp4_quantize.py`:
- Around line 136-138: Move the helper function _is_cute_dsl_available() so it
appears before its first call in tests/utils/test_fp4_quantize.py; specifically,
relocate the function definition (currently at lines ~353-360) up into the
helper section near _is_fp4_supported() (around line ~27) so that the call
inside the backend check (if backend == "cute-dsl": if not
_is_cute_dsl_available(): pytest.skip(...)) references a previously defined
function, preserving existing name/signature and any imports used by
_is_cute_dsl_available().
- Around line 666-672: The test uses torch.testing.assert_close with rtol=0 and
atol=0 comparing dq_cuda and dq_cute which enforces bit-exact equality and may
cause flaky failures; update the assertion in tests/utils/test_fp4_quantize.py
to allow a small numeric tolerance (e.g., set a non-zero rtol and/or atol) or
add a clear comment documenting why bit-exact parity is required; specifically
modify the assertion call for torch.testing.assert_close(dq_cuda, dq_cute,
rtol=..., atol=..., msg=error_msg) to use appropriate tolerances or add an
explanatory comment near dq_cuda/dq_cute/error_msg explaining the rationale.
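The tolerance rule being discussed can be modeled in a few lines of plain Python (a minimal model of `torch.testing.assert_close`'s elementwise check, not torch itself):

```python
def assert_close(actual, expected, rtol=0.0, atol=0.0):
    """Minimal model of torch.testing.assert_close's tolerance rule:
    |actual - expected| <= atol + rtol * |expected| must hold elementwise."""
    for a, e in zip(actual, expected):
        if abs(a - e) > atol + rtol * abs(e):
            raise AssertionError(f"{a} != {e} within rtol={rtol}, atol={atol}")

# rtol=0, atol=0 demands bit-exact equality between backends;
# a tiny tolerance absorbs last-ulp differences instead:
assert_close([1.0, 2.0], [1.0, 2.0])          # exact match passes
assert_close([1.0 + 1e-7], [1.0], rtol=1e-6)  # within tolerance
```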
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 150c0967-33d6-405d-a654-5bb968d4a14a
📒 Files selected for processing (10)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/quantization.py
- flashinfer/cute_dsl/fp4_common.py
- flashinfer/quantization/__init__.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/quantization/kernels/__init__.py
- flashinfer/quantization/kernels/mxfp4_quantize.py
- flashinfer/quantization/kernels/nvfp4_quantize.py
- flashinfer/quantization/quantization_cute_dsl_utils.py
- tests/utils/test_fp4_quantize.py
def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool:
    """Determine if TMA kernel should be used based on problem dimensions."""
    if dtype == torch.float8_e4m3fn:
        return False
    if k % _TMA_COLS_PER_STAGE != 0:
        return False
    if m < _TMA_MIN_M:
        return False
    # Use log2(M) + log2(K) threshold for the crossover point
    return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD
Use the documented M*K cutoff for TMA dispatch.
This predicate uses floor(log2(M)) + floor(log2(K)), so many rectangular cases above the intended 2^25 threshold still fall back to the vector-load kernel. m * k >= 1 << _TMA_LOG2_MK_THRESHOLD matches the stated dispatch rule exactly.
Suggested fix
def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool:
"""Determine if TMA kernel should be used based on problem dimensions."""
if dtype == torch.float8_e4m3fn:
return False
if k % _TMA_COLS_PER_STAGE != 0:
return False
if m < _TMA_MIN_M:
return False
- # Use log2(M) + log2(K) threshold for the crossover point
- return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD
+ return m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool:
    """Determine if TMA kernel should be used based on problem dimensions."""
    if dtype == torch.float8_e4m3fn:
        return False
    if k % _TMA_COLS_PER_STAGE != 0:
        return False
    if m < _TMA_MIN_M:
        return False
    return m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/kernels/nvfp4_quantize.py` around lines 967 - 976,
The _should_use_tma predicate currently uses floor(log2(m)) + floor(log2(k)) via
bit_length to decide the crossover; change it to the documented M*K cutoff by
replacing the final return with a check that m * k >= 1 <<
_TMA_LOG2_MK_THRESHOLD so the TMA kernel is dispatched exactly when M*K meets
the threshold; keep the existing dtype check for torch.float8_e4m3fn and the
early returns for k % _TMA_COLS_PER_STAGE and m < _TMA_MIN_M intact and update
the return statement in _should_use_tma accordingly.
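The gap between the two predicates is easy to see with a quick check. The threshold constant is taken from the review thread; the comparison itself is illustrative:

```python
_TMA_LOG2_MK_THRESHOLD = 25  # documented crossover: M*K >= 2**25

def floor_log2_predicate(m: int, k: int) -> bool:
    # The PR's approximation: floor(log2(M)) + floor(log2(K)) >= 25.
    return (m.bit_length() - 1) + (k.bit_length() - 1) >= _TMA_LOG2_MK_THRESHOLD

def exact_predicate(m: int, k: int) -> bool:
    # The documented rule: M*K >= 2**25.
    return m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)

# Powers of two agree, but rectangular shapes can differ: m = k = 6144 gives
# M*K = 37_748_736 >= 2**25, while floor(log2(6144)) + floor(log2(6144))
# = 12 + 12 = 24 < 25, so the approximation falls back to the vector kernel.
```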
The documentation should be updated to be more clear in this case. Will do so in next commit
@bkryu, understood — a clarifying comment that the floor(log₂(M)) + floor(log₂(K)) predicate is an intentional approximation of the M×K ≥ 2²⁵ threshold (not an accidental one) would make the dispatch logic self-documenting. Thanks for the follow-up!
✏️ Learnings added
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2838
File: flashinfer/quantization/kernels/nvfp4_quantize.py:967-976
Timestamp: 2026-03-23T18:58:19.372Z
Learning: In `flashinfer/quantization/kernels/nvfp4_quantize.py` (flashinfer-ai/flashinfer), the TMA dispatch predicate `m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD` (i.e., floor(log2(M)) + floor(log2(K)) >= 25) is intentional. It is a deliberate approximation of the `M*K >= 2^25` threshold — not a bug. The maintainer acknowledged this and will add a clarifying comment in a follow-up commit. Do not flag this as incorrect or suggest replacing it with `m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)`.
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/quantization/fp4_quantization.py (1)
776-778: Consider adding `.cuda()` for consistency with nvfp4_quantize.

The `_fp4_quantize_cute_dsl` helper passes `input` and `global_scale` directly to the CuTe-DSL kernels without ensuring they're on CUDA. While the underlying `nvfp4_quantize_cute_dsl` has an `assert input.is_cuda` check, adding explicit `.cuda()` calls here would:

- Match the pattern in `nvfp4_quantize` (lines 1003-1004)
- Provide clearer error messages for CPU tensor inputs

Suggested fix

-    return nvfp4_quantize_cute_dsl(
-        input, global_scale, sf_layout=sf_layout, enable_pdl=enable_pdl
-    )
+    return nvfp4_quantize_cute_dsl(
+        input.cuda(),
+        global_scale.cuda() if global_scale is not None else None,
+        sf_layout=sf_layout,
+        enable_pdl=enable_pdl,
+    )

Similar change needed for the MXFP4 path at lines 794-796.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/fp4_quantization.py` around lines 776 - 778, The helper _fp4_quantize_cute_dsl currently forwards input and global_scale to nvfp4_quantize_cute_dsl without ensuring CUDA tensors; update _fp4_quantize_cute_dsl to call .cuda() on both input and global_scale before passing them into nvfp4_quantize_cute_dsl (mirroring the pattern in nvfp4_quantize) and apply the same change for the MXFP4 path helper that calls nvfp4_quantize_cute_dsl so CPU tensors are moved to CUDA and produce clearer errors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 533-536: The reshape of scale_output after kernel_fn can mismatch
for swizzled layouts because scale_output is allocated with padded_m *
padded_sf_cols elements but is reshaped using num_sf_blocks_per_row; update the
reshape to use padded_sf_cols (keeping physical layout) or explicitly trim the
swizzled buffer to padded_m * num_sf_blocks_per_row before reshaping so the
trailing dim equals num_sf_blocks_per_row; locate usage around kernel_fn,
scale_output, num_sf_blocks_per_row, padded_sf_cols and padded_m and apply the
appropriate change (use padded_sf_cols in reshape or slice scale_output to
remove padding first).
---
Nitpick comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 776-778: The helper _fp4_quantize_cute_dsl currently forwards
input and global_scale to nvfp4_quantize_cute_dsl without ensuring CUDA tensors;
update _fp4_quantize_cute_dsl to call .cuda() on both input and global_scale
before passing them into nvfp4_quantize_cute_dsl (mirroring the pattern in
nvfp4_quantize) and apply the same change for the MXFP4 path helper that calls
nvfp4_quantize_cute_dsl so CPU tensors are moved to CUDA and produce clearer
errors.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bf77357d-39fa-48d9-badd-a2921872a820
📒 Files selected for processing (5)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/quantization.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/quantization/kernels/mxfp4_quantize.py
- flashinfer/quantization/kernels/nvfp4_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/quantization.py
    kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)

    # Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row]
    scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)
Potential reshape mismatch for swizzled layouts.
For swizzled layouts, scale_output is allocated with padded_m * padded_sf_cols elements where padded_sf_cols may be larger than num_sf_blocks_per_row (padded to multiple of 4). The reshape uses num_sf_blocks_per_row as the trailing dimension, which could fail or produce incorrect results when num_sf_blocks_per_row is not a multiple of 4.
Consider using padded_sf_cols for the reshape dimension (consistent with how NVFP4 handles this):
- scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)
+ scale_output = scale_output.reshape(-1, padded_sf_cols)

Alternatively, if the intent is to expose logical dimensions, trim the swizzled buffer first or document that callers must interpret the physical layout.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 533 - 536,
The reshape of scale_output after kernel_fn can mismatch for swizzled layouts
because scale_output is allocated with padded_m * padded_sf_cols elements but is
reshaped using num_sf_blocks_per_row; update the reshape to use padded_sf_cols
(keeping physical layout) or explicitly trim the swizzled buffer to padded_m *
num_sf_blocks_per_row before reshaping so the trailing dim equals
num_sf_blocks_per_row; locate usage around kernel_fn, scale_output,
num_sf_blocks_per_row, padded_sf_cols and padded_m and apply the appropriate
change (use padded_sf_cols in reshape or slice scale_output to remove padding
first).
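The failure mode can be demonstrated with a tiny reshape model (the sizes below are hypothetical; the helper mirrors how a `reshape(-1, ncols)` behaves on a flat buffer):

```python
def reshape_rows(flat, ncols):
    """Reshape a flat buffer into rows of ncols; raises on size mismatch,
    mirroring how reshape(-1, ncols) fails for padded buffers."""
    if len(flat) % ncols != 0:
        raise ValueError(f"cannot reshape {len(flat)} elements into rows of {ncols}")
    return [flat[i:i + ncols] for i in range(0, len(flat), ncols)]

# Hypothetical sizes: 128 padded rows, 6 logical scale blocks per row,
# columns padded to a multiple of 4 -> 8 physical columns.
padded_m, num_sf_blocks_per_row, padded_sf_cols = 128, 6, 8
flat = list(range(padded_m * padded_sf_cols))

# Reshaping by the physical column count keeps rows aligned:
rows = reshape_rows(flat, padded_sf_cols)

# Reshaping by the logical count fails here (1024 % 6 != 0); with sizes
# that happened to divide evenly it would silently scramble row boundaries.
```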
/bot run

[FAILED] Pipeline #46810194: 13/20 passed
📌 Description
Performance
Performance Sweeps in M (rows) and K (cols) space comparing the two backends:
B200:
RTX PRO 6000 Workstation:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests