
perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels #2904

Open

bkryu wants to merge 8 commits into flashinfer-ai:main from bkryu:optimize_quant

Conversation


bkryu (Collaborator) commented Mar 27, 2026

📌 Description

Summary

  • Adopt dual-path kernel architecture (linear flat + swizzled row-based) for MXFP4 and NVFP4 CuTe-DSL quantization kernels.
  • Architecture changes to MXFP8 quantization for better performance.
  • Expand benchmark scripts and test coverage across all three quantization kernels; benchmarks now compare exact outputs between the CUDA and CuTe-DSL backends.
  • All mxfp4, mxfp8, and nvfp4 quantization kernels produce bitwise-identical outputs and scaling factors across the CUDA and CuTe-DSL backends.

Kernel changes
mxfp8_quantize.py

  • Adaptive 2T/SF dispatch: 2 threads per SF block for large problems (total_sf >= 65536) and 4 threads per SF block for small problems, for better memory bandwidth utilization.
  • Integer UE8M0 conversion (float_to_ue8m0_fast, ue8m0_to_inv_scale_fast): replaces SFU-based lg2.approx/ex2.approx with integer bit manipulation, freeing the SFU pipeline.
  • reduce_max_2threads: a 1-shuffle XOR reduction for the 2T path.
  • Remove unused self.dtype and self.K attributes
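The integer UE8M0 conversion works directly on the float32 bit pattern instead of going through the SFU. A minimal Python sketch of the idea (function names follow the PR; the truncating rounding here is an assumption, and the actual kernel's rounding mode may differ):

```python
import struct

def float_to_ue8m0_fast(x: float) -> int:
    # Reinterpret the float32 bits and keep only the 8-bit biased exponent.
    # Assumption: truncation toward the lower power of two; the kernel's
    # exact rounding behavior may differ.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF

def ue8m0_to_inv_scale_fast(e: int) -> float:
    # The scale is 2^(e - 127), so its reciprocal is 2^(127 - e), whose
    # float32 biased exponent is 254 - e. Build the bits directly instead
    # of calling ex2.approx (valid for normal-range e only).
    return struct.unpack("<f", struct.pack("<I", (254 - e) << 23))[0]
```

For example, float_to_ue8m0_fast(2.0) yields 128, and ue8m0_to_inv_scale_fast(128) reconstructs 0.5 without any transcendental instruction.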

mxfp4_quantize.py

  • Add a swizzled kernel; previously only the linear layout was supported.
  • Swizzled kernel: a small-K multi-row path and a large-K column-loop path, selected at compile time via const_expr(needs_col_loop).
  • Inline padding for the swizzled layout (row and column), eliminating the expensive separate flat-iteration padding passes that caused 5x+ regressions at small M.
  • Dynamic thread count via _compute_optimal_threads(K) for 100% thread utilization.
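The review discussion below quotes the swizzled tile sizes (rows padded to a multiple of 128, SF columns to a multiple of 4). Under that assumption, the padded scale-factor dimensions can be sketched as:

```python
SF_VEC_SIZE = 32  # elements per scale factor (MX block size)

def swizzled_sf_dims(m: int, k: int) -> tuple[int, int]:
    # Illustrative sketch, not the kernel's actual code: the swizzled
    # scale-factor layout pads rows to a multiple of 128 and SF columns to
    # a multiple of 4, so the SF tensor holds padded_m * padded_sf_cols
    # entries (one UE8M0 byte each).
    sf_cols = k // SF_VEC_SIZE
    padded_m = ((m + 127) // 128) * 128
    padded_sf_cols = ((sf_cols + 3) // 4) * 4
    return padded_m, padded_sf_cols
```

At small M the padding dominates (m=1 still allocates 128 padded rows), which is why a separate flat-iteration padding pass was so costly there.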

nvfp4_quantize.py

  • Same dual-path split: NVFP4QuantizeLinearKernel + NVFP4QuantizeSwizzledKernel
  • Supports all three SF layouts (128x4, 8x4, linear) with compile-time dispatch
  • Remove unused self.row_tile_size and self.ROW_ITERATIONS from TMA kernel

quantization_cute_dsl_utils.py

  • ue8m0_to_inv_scale_fast: integer bit construction replacing ex2.approx
  • reduce_max_2threads: 1-shuffle reduction for 2T/SF MXFP8 path
  • 2T/SF constants: ELTS_PER_THREAD, THREADS_PER_SF, SF_BLOCKS_PER_WARP + legacy 4T variants
  • MXFP8_2T_SF_THRESHOLD = 65536
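The 2T/SF dispatch and the 1-shuffle reduction above can be sketched in plain Python. The threshold and constants come from the list above; the shuffle is simulated on a Python list standing in for a warp shuffle with an XOR lane mask of 1, so this is an illustration of the idea rather than the kernel's code:

```python
SF_VEC_SIZE = 32
MXFP8_2T_SF_THRESHOLD = 65536

def threads_per_sf(m: int, k: int) -> int:
    # Adaptive dispatch: large problems use 2 threads per SF block
    # (16 elements/thread); small problems keep the legacy 4T/SF path.
    total_sf = (m * k) // SF_VEC_SIZE
    return 2 if total_sf >= MXFP8_2T_SF_THRESHOLD else 4

def reduce_max_2threads(lane_vals: list[float]) -> list[float]:
    # Simulated 1-shuffle XOR reduction: each lane exchanges its partial
    # max with lane ^ 1, so both lanes of a pair end up holding the pair
    # max after a single shuffle step.
    return [max(v, lane_vals[i ^ 1]) for i, v in enumerate(lane_vals)]
```

With 2 threads per SF block, a single XOR shuffle suffices for the block max, versus the two shuffle steps a 4-thread reduction would need.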

Test changes

  • test_fp4_quantize.py and test_fp8_quantize.py: Add more problem sizes.
  • test_fp4_quantize_padding.py: Add both-backend parametrization and CUDA-vs-CuTe-DSL parity test for linear layout padding.

Perf comparison between backends on B200

  • mxfp8: linear gmean 1.42x, swizzled gmean 1.37x
  • mxfp4: linear gmean 1.41x, swizzled gmean 1.39x
  • nvfp4: linear gmean 1.34x, swizzled gmean 1.32x

(Heatmaps attached per kernel and layout: mxfp8_backend_comparison_{linear,swizzled}_bfloat16, mxfp4_quantize_backend_comparison_{linear,swizzled}_bfloat16, nvfp4_quantize_backend_comparison_{linear,swizzled}_bfloat16.)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added linear scale-factor layout support across MXFP4/MXFP8/NVFP4 and updated benchmarks to run/report linear vs swizzled separately, producing distinct heatmaps/tables.
    • Added a new NVFP4 quantization benchmark script with bandwidth and comparison modes.
  • Refactor

    • Split quantize kernels into layout-specific implementations and introduced MXFP8 dual-mode (2T/4T per SF) threading optimizations.
  • Tests

    • Expanded test parameter sweeps, added backend parameterization and capability-aware skips.


coderabbitai bot commented Mar 27, 2026

📝 Walkthrough

Adds layout-specific MXFP4/MXFP8/NVFP4 CuTe-DSL quantize kernels (linear vs swizzled), threads/layout selection and compile-time modes, updates quantization utilities, expands benchmarks to run both layouts (bandwidth and comparison modes), and broadens tests to cover new shapes and backend gating.

Changes

Cohort / File(s) Summary
Benchmarks — MXFP4
benchmarks/bench_mxfp4_quantize_backend_comparison.py
Threaded is_sf_swizzled_layout through correctness and timing flows; switched to flashinfer.quantization.fp4_quantization.fp4_quantize APIs; compute per-run global_sf; run separate linear/swizzled sweeps; configurable layout_name for reporting and heatmaps.
Benchmarks — MXFP8
benchmarks/bench_mxfp8_quantize_backend_comparison.py
Added verify_mxfp8_correctness(...) performing quant/scale agreement and cosine-similarity checks; run per-(m,k) verification and skip failed cases; expanded small-M sweep values and improved console status output.
Benchmarks — NVFP4 (new)
benchmarks/bench_nvfp4_quantize_backend_comparison.py
New script implementing NVFP4 linear/swizzled benchmarks and correctness checks, bandwidth-mode measurement, SM/CuTe-DSL gating, time/bandwidth sweeps, and heatmap/table generation.
Kernel exports
flashinfer/quantization/kernels/__init__.py
Replaced single MXFP4QuantizeKernel export with MXFP4QuantizeLinearKernel and MXFP4QuantizeSwizzledKernel in public exports.
MXFP4 kernels
flashinfer/quantization/kernels/mxfp4_quantize.py
Split unified MXFP4 kernel into MXFP4QuantizeLinearKernel and MXFP4QuantizeSwizzledKernel; removed dual-path runtime dispatch; added layout-specific thread/block strategies, compile-time cache keys, and different launch metadata.
MXFP8 kernels
flashinfer/quantization/kernels/mxfp8_quantize.py
Introduced use_2t_per_sf compile-time mode (2T/SF vs 4T/SF), recomputed thread-to-element mapping, added 2-thread reduction path and dual-load/store path, and threaded mode into cache keys and constructors.
NVFP4 kernels
flashinfer/quantization/kernels/nvfp4_quantize.py
Added NVFP4QuantizeLinearKernel, refactored swizzled kernel thread planning/control flow, and made _get_compiled_kernel_nvfp4 return layout-specific kernel + block-unit metadata.
Quantization utils
flashinfer/quantization/quantization_cute_dsl_utils.py
Changed MXFP8 defaults to 2T/SF (16 elts/thread) while keeping legacy small-problem constants; replaced ex2.approx path in ue8m0_to_inv_scale_fast with integer float construction; added reduce_max_2threads and exported new constants.
Tests — FP4/FP8
tests/utils/test_fp4_quantize.py, tests/utils/test_fp4_quantize_padding.py, tests/utils/test_fp8_quantize.py
Expanded MXFP4/NVFP4 shape lists to include very small/odd/large-K cases; added backend param and CuTe-DSL availability gating in padding tests; updated MXFP8 tests to new compiled-kernel helper names and enlarged m/k parameter grids.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • jimmyzho
  • cyx-6
  • yongwww

Poem

🐰 I hopped through kernels, split the paths today,

Linear rows and swizzled hops now both can play.
Benchmarks hum and heatmaps glow, tests bound in a sprint,
I nibble code and carrot bytes — two layouts in a sprint.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 69.12%, which is below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title accurately summarizes the main change: optimizing CuTe-DSL quantization kernels for fp4 and fp8 formats, the primary focus of this changeset.
  • Description check ✅ Passed: The PR description is comprehensive, covering kernel changes, architectural improvements, test and benchmark updates, and performance results with heatmaps.


bkryu changed the title from "perf: Optimize CuTe-DSL quantization kernels" to "perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels" on Mar 27, 2026

gemini-code-assist bot left a comment


Code Review

This pull request introduces significant optimizations and structural improvements to the MXFP4, MXFP8, and NVFP4 quantization kernels in the CuTe-DSL backend. Key changes include the implementation of a dual-path optimization strategy (linear vs. swizzled layouts) to improve thread utilization, the addition of correctness verification in benchmark suites, and the introduction of a new NVFP4 benchmark. I have identified an inconsistency in the run_benchmark_sweep function docstring where a no_verify parameter is documented but missing from the function signature.

k_values: List of K dimensions to benchmark
dtype: Input dtype
is_sf_swizzled_layout: Whether to use swizzled scale factor layout
no_verify: Skip correctness verification for pure timing runs

medium

The no_verify parameter is mentioned in the docstring but is not present in the function signature. This creates an inconsistency between the documentation and the actual function API.

Suggested change
-    no_verify: Skip correctness verification for pure timing runs
+    is_sf_swizzled_layout: Whether to use swizzled scale factor layout

bkryu (Collaborator, Author):

Good point, deleting in a follow-up commit.

coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/quantization/kernels/__init__.py (1)

44-52: ⚠️ Potential issue | 🟡 Minor

Re-export NVFP4QuantizeLinearKernel here.

flashinfer.quantization.kernels.nvfp4_quantize now publishes NVFP4QuantizeLinearKernel, but this package surface still omits it. That leaves from flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel broken even though the layout split is now public.

🔧 Proposed fix
 from .nvfp4_quantize import (
+    NVFP4QuantizeLinearKernel,
     NVFP4QuantizeSwizzledKernel,
     nvfp4_quantize_cute_dsl,
 )
@@
+    "NVFP4QuantizeLinearKernel",
     "NVFP4QuantizeSwizzledKernel",
     "nvfp4_quantize_cute_dsl",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/__init__.py` around lines 44 - 52, The
package __all__ list is missing NVFP4QuantizeLinearKernel which prevents
re-exporting it; update the __all__ in the kernels package to include
"NVFP4QuantizeLinearKernel" alongside the other symbols (e.g., add
"NVFP4QuantizeLinearKernel" to the __all__ list that currently contains
"NVFP4QuantizeSwizzledKernel", "nvfp4_quantize_cute_dsl", etc.) so that from
flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel works as
expected.
benchmarks/bench_mxfp4_quantize_backend_comparison.py (2)

223-257: ⚠️ Potential issue | 🟠 Major

Count swizzled padding in the bandwidth numerator.

This helper always treats scale-factor traffic as m * k / 32, but the swizzled MXFP4 path writes padded_m * padded_sf_cols bytes in flashinfer/quantization/kernels/mxfp4_quantize.py Lines 668-670. The reported TB/s is therefore inflated, especially for small M.

📏 Proposed fix
-def compute_bandwidth_tb_per_sec(
-    m: int, k: int, dtype: torch.dtype, time_ms: float
+def compute_bandwidth_tb_per_sec(
+    m: int,
+    k: int,
+    dtype: torch.dtype,
+    time_ms: float,
+    is_sf_swizzled_layout: bool,
 ) -> float:
@@
-    num_scale_factors = num_elements // SF_VEC_SIZE
+    if is_sf_swizzled_layout:
+        padded_m = ((m + 128 - 1) // 128) * 128
+        padded_sf_cols = (((k // SF_VEC_SIZE) + 3) // 4) * 4
+        num_scale_factors = padded_m * padded_sf_cols
+    else:
+        num_scale_factors = num_elements // SF_VEC_SIZE

You'll also need to thread is_sf_swizzled_layout through the run_bandwidth_sweep call site.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 223 -
257, compute_bandwidth_tb_per_sec currently computes scale-factor bytes as
num_elements // SF_VEC_SIZE which ignores swizzled padding and thus overstates
TB/s for swizzled layout; modify compute_bandwidth_tb_per_sec to accept an
is_sf_swizzled_layout flag (or padded dims) and when true compute scale-factor
traffic using padded_m and padded_sf_cols (matching the swizzled write size used
in mxfp4_quantize.py for the MXFP4 path) instead of m * k / SF_VEC_SIZE, i.e.,
calculate num_scale_factors = padded_m * padded_sf_cols and include that in
problem_bytes; also thread the new is_sf_swizzled_layout argument through
run_bandwidth_sweep call sites so the bandwidth helper knows when to use padded
counts.

138-158: ⚠️ Potential issue | 🟠 Major

Exclude non-bitwise-equal cases from the MXFP4 timing sweep.

This has the same hole as the NVFP4 benchmark: quant_match_pct and scale_match_pct are recorded, but the case still counts as verified if cosine stays above 0.9. That makes the benchmark tables look valid even when the backends diverge.

✅ Proposed fix
         # Check backend agreement
         quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100
         scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100
+        if not torch.equal(quant_cuda, quant_cute) or not torch.equal(
+            scale_cuda, scale_cute
+        ):
+            return (
+                False,
+                f"Backend mismatch: quant={quant_match_pct:.1f}%, scale={scale_match_pct:.1f}%",
+                quant_match_pct,
+                scale_match_pct,
+            )

         # FP4 quantization should have cosine similarity > 0.9
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 138 -
158, The current verification returns success whenever cosine similarity
(cos_sim_cuda or cos_sim_cute) >= 0.9 even if quantized outputs differ; change
the logic to require bitwise-equal quantization and scales before marking a case
as verified: compute quant_match_pct and scale_match_pct and if either is <
100.0, return False (or exclude from timing sweep) with a clear message
including quant_match_pct and scale_match_pct, otherwise continue to the cosine
checks; update the block that currently checks cos_sim_cuda/cos_sim_cute so that
the bitwise-equality check (quant_match_pct==100 and scale_match_pct==100) is
performed first.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_mxfp8_quantize_backend_comparison.py`:
- Around line 98-129: The comparison is wrong: scale_match_pct compares scale
tensors with float8 semantics and the cosine similarity uses raw FP8 payloads
instead of dequantized floats. Fix by comparing raw uint8 bytes for both
quantized payloads and scale carriers (use scale_cuda.view(torch.uint8) and
scale_cute.view(torch.uint8) to compute scale_match_pct), and compute cosine
similarity on dequantized outputs (dequantize quant_cuda and quant_cute using
their corresponding scale_cuda/scale_cute and the FP8 format—don’t just
.to(torch.float32) on the raw bytes; produce dq_cuda and dq_cute as true float32
reconstructions before calling torch.nn.functional.cosine_similarity). Ensure
you keep the existing variable names (quant_cuda, quant_cute, scale_cuda,
scale_cute, dq_cuda, dq_cute) so the rest of the function uses the corrected
values.

In `@benchmarks/bench_nvfp4_quantize_backend_comparison.py`:
- Around line 141-161: After computing quant_match_pct and scale_match_pct, add
a strict backend-agreement check that fails the verification if either
percentage is less than 100; specifically, if quant_match_pct < 100 or
scale_match_pct < 100 return (False, f"Backend mismatch:
quant_match_pct={quant_match_pct:.4f}%, scale_match_pct={scale_match_pct:.4f}%",
quant_match_pct, scale_match_pct). Keep this check alongside the existing
cosine-threshold checks (using cos_sim_cuda and cos_sim_cute) so the function
only returns success when both roundtrip quality and bitwise agreement between
quant_cuda and quant_cute (and scales) are satisfied.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 668-685: The reshape of scale_output after kernel execution uses
num_sf_blocks_per_row but the swizzled path and the allocation use
padded_sf_cols (scale_output_size = padded_m * padded_sf_cols), causing a
runtime size mismatch for 4-way padded SF columns; update the reshape to use
padded_sf_cols instead of num_sf_blocks_per_row (i.e., scale_output =
scale_output.reshape(-1, padded_sf_cols)) and ensure this change is applied
alongside references to padded_m/padded_sf_cols around kernel_fn and
scale_output allocation so the swizzled scale layout is consistent.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 114-116: The warp-count computation in
_compute_optimal_warps_for_k() always uses the constant SF_BLOCKS_PER_WARP (16)
when computing gcd/divisibility, but the legacy path should use
SF_BLOCKS_PER_WARP_SMALL (8) when use_2t_per_sf is False; change the logic so
the function selects the active sf_blocks_per_warp (e.g., set sf_blocks_per_warp
= SF_BLOCKS_PER_WARP if use_2t_per_sf else SF_BLOCKS_PER_WARP_SMALL) and use
that variable in the gcd calculation and any subsequent divisibility/warp-count
math (replace uses of SF_BLOCKS_PER_WARP in the gcd_val and warp derivation with
sf_blocks_per_warp) so rows_per_block and warp/thread counts are computed
correctly for K like 3072.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9aa2c7e3-005c-4fea-94b0-406037e2f3b6

📥 Commits

Reviewing files that changed from the base of the PR and between 31b63bc and 8a9545c.

📒 Files selected for processing (11)
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py
  • benchmarks/bench_mxfp8_quantize_backend_comparison.py
  • benchmarks/bench_nvfp4_quantize_backend_comparison.py
  • flashinfer/quantization/kernels/__init__.py
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • flashinfer/quantization/quantization_cute_dsl_utils.py
  • tests/utils/test_fp4_quantize.py
  • tests/utils/test_fp4_quantize_padding.py
  • tests/utils/test_fp8_quantize.py

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 384-385: The swizzled-path calculation of threads_per_row uses
num_sf_blocks_per_row so sf_col_idx never iterates into the padded SF columns
(indices num_sf_blocks_per_row .. padded_sf_cols-1), leaving those padding
columns uninitialized (scale_output from torch.empty). Update the swizzled
path(s) where threads_per_row is set (referenced by threads_per_row,
num_sf_blocks_per_row, _threads_per_sf) to size rows by padded_sf_cols (i.e.,
use padded_sf_cols * _threads_per_sf) so sf_col_idx loops cover padded_sf_cols,
and ensure the padding-column clear block and any zeroing logic that uses
sf_col_idx and padded_sf_cols runs for those extra columns; also ensure
scale_output is allocated/initialized accordingly rather than relying on
torch.empty leaving bytes undefined.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 53a4f10b-27c1-47c7-b4e6-a08b0bf9046c

📥 Commits

Reviewing files that changed from the base of the PR and between 8a9545c and 0eda88e.

📒 Files selected for processing (4)
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • tests/utils/test_fp8_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/utils/test_fp8_quantize.py
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py

coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/quantization/kernels/mxfp8_quantize.py (1)

377-390: ⚠️ Potential issue | 🔴 Critical

Require exact divisibility before using the multi-row swizzled mapping.

If _compute_optimal_warps() falls back to 16 warps, Lines 389-390 can still select the small-K path even when col_units_per_block is not divisible by num_sf_blocks_per_row. For K=1056 (33 SF blocks), Line 574 then creates a partial extra row (512 / 66 threads-per-row in 2T/SF, 512 / 132 in 4T/SF), so those tail threads alias the next row batch’s first row. Gate the multi-row path on exact divisibility and otherwise fall back to needs_col_loop=True.

🔧 Suggested fix
         # rows_per_block = col_units_per_block // num_sf_blocks_per_row
-        # With optimal warps, this should divide evenly for small K
-        if self.num_sf_blocks_per_row <= col_units_per_block:
+        # Multi-row processing requires exact row tiling; otherwise the tail
+        # threads spill into a partial extra row and overlap the next row batch.
+        if (
+            self.num_sf_blocks_per_row <= col_units_per_block
+            and col_units_per_block % self.num_sf_blocks_per_row == 0
+        ):
             self.rows_per_block = col_units_per_block // self.num_sf_blocks_per_row
             self.needs_col_loop = False
         else:
             self.rows_per_block = 1
             self.needs_col_loop = True

Also applies to: 574-577

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 377 - 390,
The multi-row swizzle path currently assumes col_units_per_block divides
num_sf_blocks_per_row; update the gating logic in the block that computes
self.warps_per_block (call site uses _compute_optimal_warps) to require exact
divisibility before enabling rows_per_block: check (col_units_per_block %
self.num_sf_blocks_per_row == 0) and only then set self.rows_per_block =
col_units_per_block // self.num_sf_blocks_per_row; otherwise do not set
rows_per_block and force the fallback by setting self.needs_col_loop = True (and
ensure any later code that relies on rows_per_block uses the fallback when
needs_col_loop is True). This change addresses aliasing when
_compute_optimal_warps returns a fallback warp count (e.g., 16) for K values
like 1056.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 377-390: The multi-row swizzle path currently assumes
col_units_per_block divides num_sf_blocks_per_row; update the gating logic in
the block that computes self.warps_per_block (call site uses
_compute_optimal_warps) to require exact divisibility before enabling
rows_per_block: check (col_units_per_block % self.num_sf_blocks_per_row == 0)
and only then set self.rows_per_block = col_units_per_block //
self.num_sf_blocks_per_row; otherwise do not set rows_per_block and force the
fallback by setting self.needs_col_loop = True (and ensure any later code that
relies on rows_per_block uses the fallback when needs_col_loop is True). This
change addresses aliasing when _compute_optimal_warps returns a fallback warp
count (e.g., 16) for K values like 1056.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7ea1b97-a0cb-4167-aa30-cb8db8e97129

📥 Commits

Reviewing files that changed from the base of the PR and between 0eda88e and 07ba503.

📒 Files selected for processing (1)
  • flashinfer/quantization/kernels/mxfp8_quantize.py


bkryu commented Mar 27, 2026

/bot run

flashinfer-bot (Collaborator):

GitLab MR !468 has been created, and the CI pipeline #47118274 is currently running. I'll report back once the pipeline job completes.

flashinfer-bot (Collaborator):

[FAILED] Pipeline #47118274: 13/20 passed
