perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels#2904
bkryu wants to merge 8 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds layout-specific MXFP4/MXFP8/NVFP4 CuTe-DSL quantize kernels (linear vs. swizzled), adds thread/layout selection and compile-time modes, updates quantization utilities, expands benchmarks to run both layouts (bandwidth and comparison modes), and broadens tests to cover new shapes and backend gating.

Changes
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces significant optimizations and structural improvements to the MXFP4, MXFP8, and NVFP4 quantization kernels in the CuTe-DSL backend. Key changes include the implementation of a dual-path optimization strategy (linear vs. swizzled layouts) to improve thread utilization, the addition of correctness verification in benchmark suites, and the introduction of a new NVFP4 benchmark. I have identified an inconsistency in the `run_benchmark_sweep` function docstring, where a `no_verify` parameter is documented but missing from the function signature.
> k_values: List of K dimensions to benchmark
> dtype: Input dtype
> is_sf_swizzled_layout: Whether to use swizzled scale factor layout
> no_verify: Skip correctness verification for pure timing runs
The `no_verify` parameter is mentioned in the docstring but is not present in the function signature. This creates an inconsistency between the documentation and the actual function API.
Suggested change:

```diff
-        no_verify: Skip correctness verification for pure timing runs
+        is_sf_swizzled_layout: Whether to use swizzled scale factor layout
```
Good point. Deleting in a follow-up commit.
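One way to keep the docstring and signature from drifting again is a quick `inspect`-based consistency check. A minimal sketch, where `run_benchmark_sweep` is a stub carrying the corrected docstring rather than the real benchmark:

```python
import inspect

def run_benchmark_sweep(m, k_values, dtype=None, is_sf_swizzled_layout=False):
    """Run a benchmark sweep.

    Args:
        m: M dimension
        k_values: List of K dimensions to benchmark
        dtype: Input dtype
        is_sf_swizzled_layout: Whether to use swizzled scale factor layout
    """

# Every arg documented under "Args:" must also exist in the signature.
sig_params = set(inspect.signature(run_benchmark_sweep).parameters)
documented = {
    line.strip().split(":")[0]
    for line in inspect.getdoc(run_benchmark_sweep).splitlines()
    if line.startswith("    ") and ":" in line
}
assert documented <= sig_params, documented - sig_params
```

A check like this can live in a unit test so a stale `no_verify` entry fails CI instead of reaching review.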
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/quantization/kernels/__init__.py (1)

44-52: ⚠️ Potential issue | 🟡 Minor — Re-export `NVFP4QuantizeLinearKernel` here.

`flashinfer.quantization.kernels.nvfp4_quantize` now publishes `NVFP4QuantizeLinearKernel`, but this package surface still omits it. That leaves `from flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel` broken even though the layout split is now public.

🔧 Proposed fix

```diff
 from .nvfp4_quantize import (
+    NVFP4QuantizeLinearKernel,
     NVFP4QuantizeSwizzledKernel,
     nvfp4_quantize_cute_dsl,
 )
@@
+    "NVFP4QuantizeLinearKernel",
     "NVFP4QuantizeSwizzledKernel",
     "nvfp4_quantize_cute_dsl",
 ]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `flashinfer/quantization/kernels/__init__.py` around lines 44-52: the package `__all__` list is missing `NVFP4QuantizeLinearKernel`, which prevents re-exporting it. Add `"NVFP4QuantizeLinearKernel"` to the `__all__` list alongside the other symbols (`"NVFP4QuantizeSwizzledKernel"`, `"nvfp4_quantize_cute_dsl"`, etc.) so that `from flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel` works as expected.

benchmarks/bench_mxfp4_quantize_backend_comparison.py (2)
223-257: ⚠️ Potential issue | 🟠 Major — Count swizzled padding in the bandwidth numerator.

This helper always treats scale-factor traffic as `m * k / 32`, but the swizzled MXFP4 path writes `padded_m * padded_sf_cols` bytes in `flashinfer/quantization/kernels/mxfp4_quantize.py` (lines 668-670). The reported TB/s is therefore inflated, especially for small `M`.

📏 Proposed fix

```diff
-def compute_bandwidth_tb_per_sec(
-    m: int, k: int, dtype: torch.dtype, time_ms: float
+def compute_bandwidth_tb_per_sec(
+    m: int,
+    k: int,
+    dtype: torch.dtype,
+    time_ms: float,
+    is_sf_swizzled_layout: bool,
 ) -> float:
@@
-    num_scale_factors = num_elements // SF_VEC_SIZE
+    if is_sf_swizzled_layout:
+        padded_m = ((m + 128 - 1) // 128) * 128
+        padded_sf_cols = (((k // SF_VEC_SIZE) + 3) // 4) * 4
+        num_scale_factors = padded_m * padded_sf_cols
+    else:
+        num_scale_factors = num_elements // SF_VEC_SIZE
```

You'll also need to thread `is_sf_swizzled_layout` through the `run_bandwidth_sweep` call site.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 223-257: `compute_bandwidth_tb_per_sec` currently computes scale-factor bytes as `num_elements // SF_VEC_SIZE`, which ignores swizzled padding and thus overstates TB/s for the swizzled layout. Modify `compute_bandwidth_tb_per_sec` to accept an `is_sf_swizzled_layout` flag (or padded dims) and, when true, compute scale-factor traffic as `padded_m * padded_sf_cols` (matching the swizzled write size used in `mxfp4_quantize.py` for the MXFP4 path) instead of `m * k / SF_VEC_SIZE`, and include that in `problem_bytes`. Also thread the new `is_sf_swizzled_layout` argument through the `run_bandwidth_sweep` call sites so the bandwidth helper knows when to use padded counts.
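The proposed accounting can be sketched end to end. The traffic model here (input read, packed FP4 write at 0.5 B/element, one byte per scale factor) and the `elem_bytes` parameter are assumptions standing in for the real helper's `torch.dtype` handling:

```python
SF_VEC_SIZE = 32  # MXFP4 scale-factor group size used by the benchmark

def compute_bandwidth_tb_per_sec(m, k, elem_bytes, time_ms, is_sf_swizzled_layout):
    """Sketch of the proposed fix; elem_bytes stands in for torch.dtype."""
    num_elements = m * k
    if is_sf_swizzled_layout:
        # Swizzled output pads rows to 128 and SF columns to a multiple of 4,
        # mirroring the allocation described for mxfp4_quantize.py.
        padded_m = ((m + 128 - 1) // 128) * 128
        padded_sf_cols = (((k // SF_VEC_SIZE) + 3) // 4) * 4
        num_scale_factors = padded_m * padded_sf_cols
    else:
        num_scale_factors = num_elements // SF_VEC_SIZE
    # Assumed traffic model: read input, write packed FP4 (0.5 B/elem),
    # write one byte per scale factor.
    problem_bytes = num_elements * elem_bytes + num_elements // 2 + num_scale_factors
    return problem_bytes / (time_ms * 1e-3) / 1e12

# For small M the swizzled layout moves many more scale-factor bytes,
# so the corrected numerator reports a higher (honest) byte count:
linear = compute_bandwidth_tb_per_sec(8, 1024, 2, 1.0, False)
swizzled = compute_bandwidth_tb_per_sec(8, 1024, 2, 1.0, True)
assert swizzled > linear
```

At `m=128` (already row-aligned) the two paths agree, which is a useful sanity check when wiring the flag through `run_bandwidth_sweep`.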
138-158: ⚠️ Potential issue | 🟠 Major — Exclude non-bitwise-equal cases from the MXFP4 timing sweep.

This has the same hole as the NVFP4 benchmark: `quant_match_pct` and `scale_match_pct` are recorded, but the case still counts as verified if cosine similarity stays above 0.9. That makes the benchmark tables look valid even when the backends diverge.

✅ Proposed fix

```diff
     # Check backend agreement
     quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100
     scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100
+    if not torch.equal(quant_cuda, quant_cute) or not torch.equal(
+        scale_cuda, scale_cute
+    ):
+        return (
+            False,
+            f"Backend mismatch: quant={quant_match_pct:.1f}%, scale={scale_match_pct:.1f}%",
+            quant_match_pct,
+            scale_match_pct,
+        )
     # FP4 quantization should have cosine similarity > 0.9
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 138-158: the current verification returns success whenever cosine similarity (`cos_sim_cuda` or `cos_sim_cute`) >= 0.9 even if the quantized outputs differ. Change the logic to require bitwise-equal quantization and scales before marking a case as verified: compute `quant_match_pct` and `scale_match_pct`, and if either is below 100.0, return False (or exclude the case from the timing sweep) with a clear message including both percentages; otherwise continue to the cosine checks. Update the block that currently checks `cos_sim_cuda`/`cos_sim_cute` so the bitwise-equality check (`quant_match_pct == 100` and `scale_match_pct == 100`) is performed first.
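The ordering the fix calls for (bitwise agreement first, cosine checks second) can be sketched in plain Python; lists stand in for the torch tensors, and the names mirror the benchmark:

```python
def verify_outputs(quant_cuda, quant_cute, scale_cuda, scale_cute):
    """Return (ok, message, quant_match_pct, scale_match_pct)."""
    def match_pct(a, b):
        return 100.0 * sum(x == y for x, y in zip(a, b)) / len(a)

    quant_match_pct = match_pct(quant_cuda, quant_cute)
    scale_match_pct = match_pct(scale_cuda, scale_cute)
    # Bitwise agreement is checked before any cosine-similarity threshold,
    # so a diverging backend can never be reported as verified.
    if quant_match_pct < 100.0 or scale_match_pct < 100.0:
        return (
            False,
            f"Backend mismatch: quant={quant_match_pct:.1f}%, scale={scale_match_pct:.1f}%",
            quant_match_pct,
            scale_match_pct,
        )
    # ...cosine-similarity roundtrip checks would follow here...
    return True, "OK", quant_match_pct, scale_match_pct
```

The timing sweep then only times cases where `verify_outputs` returns `True`, so the tables cannot show a speedup for outputs that disagree.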
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_mxfp8_quantize_backend_comparison.py`:
- Around line 98-129: The comparison is wrong: scale_match_pct compares scale
tensors with float8 semantics and the cosine similarity uses raw FP8 payloads
instead of dequantized floats. Fix by comparing raw uint8 bytes for both
quantized payloads and scale carriers (use scale_cuda.view(torch.uint8) and
scale_cute.view(torch.uint8) to compute scale_match_pct), and compute cosine
similarity on dequantized outputs (dequantize quant_cuda and quant_cute using
their corresponding scale_cuda/scale_cute and the FP8 format—don’t just
.to(torch.float32) on the raw bytes; produce dq_cuda and dq_cute as true float32
reconstructions before calling torch.nn.functional.cosine_similarity). Ensure
you keep the existing variable names (quant_cuda, quant_cute, scale_cuda,
scale_cute, dq_cuda, dq_cute) so the rest of the function uses the corrected
values.
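The core of that fix is that cosine similarity must be computed on dequantized floats, not on raw FP8 payload bytes. A pure-Python sketch with a hypothetical 3-entry decode table standing in for the FP8 format (real e4m3/e5m2 decoding has 256 entries and per-32-element block scales):

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

def dequantize(codes, scales, decode):
    # decode maps a raw payload byte to its float value; here each element
    # carries its own scale for brevity.
    return [decode(c) * s for c, s in zip(codes, scales)]

# Hypothetical decode table standing in for the FP8 format.
decode = {0: 0.0, 1: 0.5, 2: 1.0}.get

# Different payload bytes can still reconstruct identical values once the
# scales are applied, so compare dq_cuda/dq_cute, never the raw bytes.
dq_cuda = dequantize([1, 2], [4.0, 2.0], decode)  # [2.0, 2.0]
dq_cute = dequantize([2, 1], [2.0, 4.0], decode)  # [2.0, 2.0]
assert cosine_similarity(dq_cuda, dq_cute) > 0.99
```

Conversely, calling `.to(torch.float32)` on the raw bytes would compare `[1, 2]` against `[2, 1]` and understate agreement, which is exactly the bug the prompt describes.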
In `@benchmarks/bench_nvfp4_quantize_backend_comparison.py`:
- Around line 141-161: After computing quant_match_pct and scale_match_pct, add
a strict backend-agreement check that fails the verification if either
percentage is less than 100; specifically, if quant_match_pct < 100 or
scale_match_pct < 100 return (False, f"Backend mismatch:
quant_match_pct={quant_match_pct:.4f}%, scale_match_pct={scale_match_pct:.4f}%",
quant_match_pct, scale_match_pct). Keep this check alongside the existing
cosine-threshold checks (using cos_sim_cuda and cos_sim_cute) so the function
only returns success when both roundtrip quality and bitwise agreement between
quant_cuda and quant_cute (and scales) are satisfied.
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 668-685: The reshape of scale_output after kernel execution uses
num_sf_blocks_per_row but the swizzled path and the allocation use
padded_sf_cols (scale_output_size = padded_m * padded_sf_cols), causing a
runtime size mismatch for 4-way padded SF columns; update the reshape to use
padded_sf_cols instead of num_sf_blocks_per_row (i.e., scale_output =
scale_output.reshape(-1, padded_sf_cols)) and ensure this change is applied
alongside references to padded_m/padded_sf_cols around kernel_fn and
scale_output allocation so the swizzled scale layout is consistent.
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 114-116: The warp-count computation in
_compute_optimal_warps_for_k() always uses the constant SF_BLOCKS_PER_WARP (16)
when computing gcd/divisibility, but the legacy path should use
SF_BLOCKS_PER_WARP_SMALL (8) when use_2t_per_sf is False; change the logic so
the function selects the active sf_blocks_per_warp (e.g., set sf_blocks_per_warp
= SF_BLOCKS_PER_WARP if use_2t_per_sf else SF_BLOCKS_PER_WARP_SMALL) and use
that variable in the gcd calculation and any subsequent divisibility/warp-count
math (replace uses of SF_BLOCKS_PER_WARP in the gcd_val and warp derivation with
sf_blocks_per_warp) so rows_per_block and warp/thread counts are computed
correctly for K like 3072.
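The shape of that fix, sketched with the constants named in the prompt. The `warps_for_row` helper is hypothetical; the real `_compute_optimal_warps_for_k()` involves more gcd/divisibility math:

```python
SF_BLOCKS_PER_WARP = 16        # 2-threads-per-SF path
SF_BLOCKS_PER_WARP_SMALL = 8   # legacy path (use_2t_per_sf=False)

def active_sf_blocks_per_warp(use_2t_per_sf):
    # The fix: select the constant for the path actually in use before any
    # gcd/divisibility math, instead of hard-coding SF_BLOCKS_PER_WARP.
    return SF_BLOCKS_PER_WARP if use_2t_per_sf else SF_BLOCKS_PER_WARP_SMALL

def warps_for_row(num_sf_blocks_per_row, use_2t_per_sf):
    # Hypothetical downstream use: whole warps needed for one row of SF blocks.
    blocks = active_sf_blocks_per_warp(use_2t_per_sf)
    return -(-num_sf_blocks_per_row // blocks)  # ceil division

# K = 3072 with 32-element SF groups gives 96 SF blocks per row.
assert warps_for_row(96, use_2t_per_sf=True) == 6
assert warps_for_row(96, use_2t_per_sf=False) == 12
```

Any divisibility test downstream then operates on the active constant, so the legacy 8-blocks-per-warp path is no longer sized as if it packed 16.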
ℹ️ Review info — Configuration: defaults | Review profile: CHILL | Plan: Pro | Run ID: 9aa2c7e3-005c-4fea-94b0-406037e2f3b6
📒 Files selected for processing (11)
- benchmarks/bench_mxfp4_quantize_backend_comparison.py
- benchmarks/bench_mxfp8_quantize_backend_comparison.py
- benchmarks/bench_nvfp4_quantize_backend_comparison.py
- flashinfer/quantization/kernels/__init__.py
- flashinfer/quantization/kernels/mxfp4_quantize.py
- flashinfer/quantization/kernels/mxfp8_quantize.py
- flashinfer/quantization/kernels/nvfp4_quantize.py
- flashinfer/quantization/quantization_cute_dsl_utils.py
- tests/utils/test_fp4_quantize.py
- tests/utils/test_fp4_quantize_padding.py
- tests/utils/test_fp8_quantize.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 384-385: The swizzled-path calculation of threads_per_row uses
num_sf_blocks_per_row so sf_col_idx never iterates into the padded SF columns
(indices num_sf_blocks_per_row .. padded_sf_cols-1), leaving those padding
columns uninitialized (scale_output from torch.empty). Update the swizzled
path(s) where threads_per_row is set (referenced by threads_per_row,
num_sf_blocks_per_row, _threads_per_sf) to size rows by padded_sf_cols (i.e.,
use padded_sf_cols * _threads_per_sf) so sf_col_idx loops cover padded_sf_cols,
and ensure the padding-column clear block and any zeroing logic that uses
sf_col_idx and padded_sf_cols runs for those extra columns; also ensure
scale_output is allocated/initialized accordingly rather than relying on
torch.empty leaving bytes undefined.
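A sketch of the sizing change that prompt describes, assuming the SF columns pad to a multiple of 4 as elsewhere in this review (the real kernel wires this into the CuTe-DSL launch configuration):

```python
def swizzled_thread_layout(num_sf_blocks_per_row, threads_per_sf):
    """Size the per-row thread count by the padded SF columns so that
    sf_col_idx also visits the padding columns, which must be zeroed
    explicitly because scale_output comes from torch.empty."""
    padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
    # Was: num_sf_blocks_per_row * threads_per_sf, which skipped the
    # padding columns and left their bytes undefined.
    threads_per_row = padded_sf_cols * threads_per_sf
    return padded_sf_cols, threads_per_row

# 33 SF blocks pad to 36 columns; with 2 threads per SF that is 72 threads.
assert swizzled_thread_layout(33, 2) == (36, 72)
assert swizzled_thread_layout(32, 2) == (32, 64)
```

With the loop covering `padded_sf_cols`, the threads mapped to columns at indices `num_sf_blocks_per_row` and above have no source data and simply write zeros, so no byte of the swizzled output is left uninitialized.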
ℹ️ Review info — Configuration: defaults | Review profile: CHILL | Plan: Pro | Run ID: 53a4f10b-27c1-47c7-b4e6-a08b0bf9046c
📒 Files selected for processing (4)
- benchmarks/bench_mxfp4_quantize_backend_comparison.py
- flashinfer/quantization/kernels/mxfp8_quantize.py
- flashinfer/quantization/kernels/nvfp4_quantize.py
- tests/utils/test_fp8_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/utils/test_fp8_quantize.py
- benchmarks/bench_mxfp4_quantize_backend_comparison.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/quantization/kernels/mxfp8_quantize.py (1)

377-390: ⚠️ Potential issue | 🔴 Critical — Require exact divisibility before using the multi-row swizzled mapping.

If `_compute_optimal_warps()` falls back to 16 warps, lines 389-390 can still select the small-K path even when `col_units_per_block` is not divisible by `num_sf_blocks_per_row`. For `K=1056` (33 SF blocks), line 574 then creates a partial extra row (`512 / 66` threads-per-row in 2T/SF, `512 / 132` in 4T/SF), so those tail threads alias the next row batch's first row. Gate the multi-row path on exact divisibility and otherwise fall back to `needs_col_loop=True`.

🔧 Suggested fix

```diff
     # rows_per_block = col_units_per_block // num_sf_blocks_per_row
-    # With optimal warps, this should divide evenly for small K
-    if self.num_sf_blocks_per_row <= col_units_per_block:
+    # Multi-row processing requires exact row tiling; otherwise the tail
+    # threads spill into a partial extra row and overlap the next row batch.
+    if (
+        self.num_sf_blocks_per_row <= col_units_per_block
+        and col_units_per_block % self.num_sf_blocks_per_row == 0
+    ):
         self.rows_per_block = col_units_per_block // self.num_sf_blocks_per_row
         self.needs_col_loop = False
     else:
         self.rows_per_block = 1
         self.needs_col_loop = True
```

Also applies to: 574-577
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 377 - 390, The multi-row swizzle path currently assumes col_units_per_block divides num_sf_blocks_per_row; update the gating logic in the block that computes self.warps_per_block (call site uses _compute_optimal_warps) to require exact divisibility before enabling rows_per_block: check (col_units_per_block % self.num_sf_blocks_per_row == 0) and only then set self.rows_per_block = col_units_per_block // self.num_sf_blocks_per_row; otherwise do not set rows_per_block and force the fallback by setting self.needs_col_loop = True (and ensure any later code that relies on rows_per_block uses the fallback when needs_col_loop is True). This change addresses aliasing when _compute_optimal_warps returns a fallback warp count (e.g., 16) for K values like 1056.
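The gate this comment asks for reduces to a small pure function. A sketch (names follow the review; the real code sets these as attributes on the kernel object):

```python
def choose_row_mapping(col_units_per_block, num_sf_blocks_per_row):
    """Return (rows_per_block, needs_col_loop).

    The multi-row mapping is only safe when the block's column units tile
    whole rows exactly; otherwise tail threads would alias the next row
    batch's first row, as in the K=1056 case above.
    """
    if (
        num_sf_blocks_per_row <= col_units_per_block
        and col_units_per_block % num_sf_blocks_per_row == 0
    ):
        return col_units_per_block // num_sf_blocks_per_row, False
    return 1, True

# K=1056 -> 33 SF blocks per row: 512 % 33 != 0, so fall back to the column loop.
assert choose_row_mapping(512, 33) == (1, True)
# K=1024 -> 32 SF blocks per row tiles 512 column units into 16 whole rows.
assert choose_row_mapping(512, 32) == (16, False)
```

Keeping the check as a standalone function also makes it trivial to unit-test the awkward K values (1056, 3072, etc.) without launching a kernel.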
ℹ️ Review info — Configuration: defaults | Review profile: CHILL | Plan: Pro | Run ID: e7ea1b97-a0cb-4167-aa30-cb8db8e97129
📒 Files selected for processing (1)
- flashinfer/quantization/kernels/mxfp8_quantize.py
/bot run

[FAILED] Pipeline #47118274: 13/20 passed
📌 Description
Summary
Kernel changes
- mxfp8_quantize.py
- mxfp4_quantize.py
- nvfp4_quantize.py
- quantization_cute_dsl_utils.py

Test changes

- test_fp4_quantize.py and test_fp8_quantize.py: Add more problem sizes.
- test_fp4_quantize_padding.py: Add both-backend parametrization and a CUDA-vs-CuTe-DSL parity test for linear-layout padding.

Perf comparison between backends on B200
Click to see mxfp8 performance comparison
Linear (gmean 1.42x)
Swizzled (gmean 1.37x)
Click to see mxfp4 performance comparison
Linear (gmean 1.41x)
Swizzled (gmean 1.39x)
Click to see nvfp4 performance comparison
Linear (gmean 1.34x)
Swizzled (gmean 1.32x)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Refactor
Tests