fix: Add global scale support and optional output allocation for RMSNorm+FP4Quant fusion kernels#2260
Conversation
📝 Walkthrough

Adds an optional `global_scale` parameter across fused CuTe-DSL Add+RMSNorm+FP4 and RMSNorm+FP4 kernels, threads it through host→device pointer bindings and kernels, updates public APIs to accept/return `global_scale` and `(y_fp4, block_scale)`, refactors benchmarks to report FUSED vs UNFUSED timings and speedup, and extends tests with global_scale-aware checks and two-tier tolerances.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Bench as Benchmark Runner
    participant Host as Python Host API
    participant Device as CUDA Kernel
    participant Mem as Device Memory (y_fp4, block_scale)
    Note right of Bench: benchmark uses fixed global_scale
    Bench->>Host: call add_rmsnorm_fp4quant(x, r, w, ..., global_scale)
    Host->>Device: bind pointers (x, r, w, y, s, global_scale) & launch kernel
    Device->>Mem: read global_scale, compute per-block scale & inv_scale
    Device->>Mem: write y_fp4 and block_scale
    Device-->>Host: kernel returns (y_fp4_ptr, block_scale_ptr)
    Host-->>Bench: return (y_fp4, block_scale) -> record FUSED time
    Note left of Bench: UNFUSED path runs Add -> RMSNorm -> FP4Quant with same global_scale
    Bench->>Bench: compute speedup = unfused_us / fused_us
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the RMSNorm+FP4Quant fusion kernels.

Highlights
/bot run
Actionable comments posted: 2
🧹 Nitpick comments (5)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (1)
733-752: Prefix unused variables with underscore.

The unpacked variables `y_fp4_gs` and `y_fp4_no_gs` are intentionally unused since this test focuses on comparing block scale ratios. Prefixing with underscore clarifies intent and silences the linter.

🔎 Proposed fix

```diff
-    y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
+    _y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
         x,
         weight,
         global_scale=global_scale,
         eps=eps,
         block_size=block_size,
         is_sf_swizzled_layout=False,
     )
     # Run without global_scale (global_scale=1.0)
     global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda")
-    y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
+    _y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
         x,
         weight,
         global_scale=global_scale_one,
         eps=eps,
         block_size=block_size,
         is_sf_swizzled_layout=False,
     )
```

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)
690-711: Prefix unused variables with underscore.

Same as in the RMSNorm test file, the unpacked `y_fp4_gs` and `y_fp4_no_gs` variables are intentionally unused.

🔎 Proposed fix

```diff
-    y_fp4_gs, block_scale_gs = add_rmsnorm_fp4quant(
+    _y_fp4_gs, block_scale_gs = add_rmsnorm_fp4quant(
         x,
         r,
         weight,
         global_scale=global_scale,
         eps=eps,
         block_size=block_size,
         is_sf_swizzled_layout=False,
     )
     # Run without global_scale (global_scale=1.0)
     global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda")
-    y_fp4_no_gs, block_scale_no_gs = add_rmsnorm_fp4quant(
+    _y_fp4_no_gs, block_scale_no_gs = add_rmsnorm_fp4quant(
         x,
         r,
         weight,
         global_scale=global_scale_one,
         eps=eps,
         block_size=block_size,
         is_sf_swizzled_layout=False,
     )
```

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (3)
2286-2288: Add validation for global_scale tensor properties.

When `global_scale` is provided by the caller, there's no validation of its shape, dtype, or device. Invalid inputs could cause runtime errors or incorrect results.

Proposed validation

Add validation after line 2388:

```diff
  sm_version = get_sm_version(input.device)
+ # Validate global_scale if provided
+ if global_scale is not None:
+     assert global_scale.shape == (1,) or global_scale.numel() == 1, (
+         f"global_scale must have shape (1,), got {global_scale.shape}"
+     )
+     assert global_scale.dtype == torch.float32, (
+         f"global_scale must have dtype torch.float32, got {global_scale.dtype}"
+     )
+     assert global_scale.device == input.device, (
+         f"global_scale device {global_scale.device} must match input device {input.device}"
+     )
+     # Flatten to shape (1,) if needed
+     global_scale = global_scale.reshape(1)
+
  # Allocate output tensors if not provided
```

Also applies to: 2442-2443
1755-1772: Consider warning when global_scale is provided with UE8M0 format.

For UE8M0 (MXFP4, block_size=32), `global_scale` is silently ignored (lines 1755-1759, 2079-2083). Users might mistakenly provide `global_scale` expecting it to have an effect. Consider adding a validation or warning to make this explicit.

You could add a check in the public API after line 2386:

```python
if global_scale is not None and actual_scale_format == "ue8m0":
    import warnings

    warnings.warn(
        "global_scale is only supported for E4M3 format and will be ignored for UE8M0 (MXFP4)",
        UserWarning,
    )
```

Also applies to: 2079-2096
1377-1381: Clarify comment about "canceling global_scale".

The phrase "to cancel global_scale" at lines 1377 and 1559 may be confusing. The purpose is to ensure the quantized intermediate values are computed using standard quantization (without global_scale in the quantization step), while global_scale is retained in the stored block scale. Consider rewording for clarity.

Suggested comment clarification

```diff
- # inv_scale = global_scale / scale_float to cancel global_scale
+ # inv_scale excludes global_scale from quantization computation
+ # so q = y / (max_abs / FP4_MAX), while scale_fp8 = global_scale * max_abs / FP4_MAX
  inv_scale = (
      fp8_e4m3_to_f32_and_rcp(scale_fp8_u32)
      * global_scale_val
  )
```

Also applies to: 1559-1563
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
- benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

- flashinfer/cute_dsl/rmsnorm_fp4quant.py (5)
  - kernel (1091-1728)
  - fmin_f32 (208-220)
  - cvt_f32_to_e4m3 (462-482)
  - fp8_e4m3_to_f32_and_rcp (486-518)
  - get_cute_pointers (1751-1792)
- flashinfer/cute_dsl/utils.py (1)
  - make_ptr (175-223)

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)

- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)
  - compute_global_scale (83-105)
  - llama_rms_norm (31-39)
  - unswizzle_sf (851-884)
- flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
  - rmsnorm_fp4quant (1836-2013)
🪛 Ruff (0.14.10)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
74-74: Avoid specifying long messages outside the exception class
(TRY003)
143-146: Avoid specifying long messages outside the exception class
(TRY003)
690-690: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
703-703: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
343-343: Do not catch blind exception: Exception
(BLE001)
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
337-337: Do not catch blind exception: Exception
(BLE001)
350-350: Do not catch blind exception: Exception
(BLE001)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
78-78: Avoid specifying long messages outside the exception class
(TRY003)
149-152: Avoid specifying long messages outside the exception class
(TRY003)
733-733: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
745-745: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (4)
1836-1846: LGTM on the API signature update.

The function signature now correctly supports optional output allocation with `y_fp4` and `block_scale` as optional parameters, and adds `global_scale` support. The return type is appropriately updated to `Tuple[torch.Tensor, torch.Tensor]`.
1940-1993: Auto-allocation logic is well-implemented.The logic correctly handles:
- 2D vs 3D input shapes
- Scale dtype selection based on format (UE8M0 → uint8, E4M3 → float8_e4m3fn)
- Swizzled layout size calculation with 128x4 tile pattern
- Default global_scale=1.0 when not provided
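The allocation rules above reduce to simple shape arithmetic. A minimal pure-Python sketch (an illustration, not the library's actual allocator; the 128x4 padding follows the swizzled tile pattern described in this review):

```python
import math

def output_shapes(batch_size: int, hidden_size: int, block_size: int, swizzled: bool):
    """Sketch of the shapes rmsnorm_fp4quant would need to allocate.

    Two FP4 values are packed per uint8 byte, and one block scale is
    produced per `block_size` elements along the hidden dimension.
    """
    y_fp4_shape = (batch_size, hidden_size // 2)  # packed FP4 bytes
    n_blocks = hidden_size // block_size          # scales per row
    if swizzled:
        # pad to the 128x4 tile pattern used by the swizzled scale layout
        padded_rows = math.ceil(batch_size / 128) * 128
        padded_cols = math.ceil(n_blocks / 4) * 4
        scale_numel = padded_rows * padded_cols
    else:
        scale_numel = batch_size * n_blocks
    return y_fp4_shape, scale_numel

print(output_shapes(100, 128, 16, swizzled=False))  # ((100, 64), 800)
print(output_shapes(100, 128, 16, swizzled=True))   # ((100, 64), 1024)
```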
1252-1255: Verify global_scale read location for performance.

The `global_scale_val` is read from device memory inside the kernel loop. While this is correct for CUDA graph compatibility (as noted in the comment), reading it once per thread block rather than per thread would be more efficient. However, the compiler likely optimizes this to a single load per warp.
1918-1925: LGTM on input reshaping.

The 2D/3D input handling is correct. The `input_2d` variable is properly used for kernel execution, and `.contiguous()` is called at the point of use (line 2004).

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3)
87-132: LGTM on global_scale integration in benchmarks.

The `global_scale` parameter is properly threaded through to the fused kernel call. The benchmark structure correctly measures the fused kernel performance with the new parameter.
135-176: Good refactoring of unfused benchmark.

Consolidating the unfused operations (rmsnorm + fp4_quantize) into a single timed function provides a more accurate comparison against the fused kernel. The `global_scale` is correctly passed to `fp4_quantize` for NVFP4.
337-354: Blind exception handling is acceptable here.

While static analysis flags `except Exception`, this pattern is appropriate in benchmark code to ensure the suite continues running even if individual configurations fail. The error message includes sufficient context (batch_size, hidden_size, and exception details).

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (3)
90-111: LGTM on compute_global_scale helper.

The formula `global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / max_abs(rmsnorm_output)` correctly computes the optimal global scale to maximize dynamic range utilization for NVFP4 quantization.
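In scalar form the helper's formula reduces to a one-liner. A hedged sketch with the FP8-E4M3 and FP4-E2M1 maxima hardcoded (the real helper operates on the rmsnorm output tensor):

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
FP4_E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude

def compute_global_scale(max_abs: float) -> float:
    """Global scale that places per-block scales in optimal E4M3 range."""
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / max_abs

# if the rmsnorm output peaks at 448 * 6 = 2688, the optimal global scale is 1.0
print(compute_global_scale(2688.0))  # 1.0
print(compute_global_scale(6.0))     # 448.0
```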
114-152: Well-designed tiered tolerance check.The two-tier tolerance approach appropriately handles quantization noise:
- 99% of elements must match within tight tolerance (rtol=0.1, atol=0.1)
- 100% of elements must match within loose tolerance (rtol=0.5, atol=2.0)
This is more robust than a single tolerance threshold for FP4 quantized outputs.
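A pure-Python sketch of such a two-tier check (illustrative only; the actual test helper `assert_close_with_tiered_tolerance` works on torch tensors and raises on failure):

```python
def tiered_close(actual, expected, tight=(0.1, 0.1), loose=(0.5, 2.0), frac=0.99):
    """True iff >= frac of elements meet the tight (rtol, atol) tolerance
    and every element meets the loose one."""
    def ok(a, e, rtol, atol):
        return abs(a - e) <= atol + rtol * abs(e)

    n_tight = sum(ok(a, e, *tight) for a, e in zip(actual, expected))
    all_loose = all(ok(a, e, *loose) for a, e in zip(actual, expected))
    return n_tight >= frac * len(actual) and all_loose

ref = [1.0] * 100
noisy = [1.0] * 99 + [1.6]             # one quantization outlier
print(tiered_close(noisy, ref))        # True: 99% tight, outlier within loose
print(tiered_close([9.0] * 100, ref))  # False: loose tolerance violated
```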
1075-1121: Excellent test coverage for auto-allocation.

The `TestAutoAllocation` class comprehensively covers:
- 2D and 3D input shapes
- NVFP4 (with global_scale) and MXFP4 formats
- Swizzled layout auto-allocation
- Equivalence between auto-allocated and pre-allocated paths
This ensures the new optional output allocation feature works correctly across all configurations.
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
76-122: LGTM on global_scale integration.

The `global_scale` parameter is correctly propagated to the fused `add_rmsnorm_fp4quant` kernel. The benchmark structure mirrors the RMSNorm-only benchmark file, maintaining consistency.
125-167: Good unfused benchmark implementation.

The unfused benchmark correctly:

- Pre-allocates intermediate tensors (`h`, `y_normed`) outside the timed region
- Times the combined add + rmsnorm + fp4_quantize workflow
- Passes `global_scale` to `fp4_quantize` for NVFP4 consistency
343-348: Blind exception handling is acceptable in benchmark code.

Similar to the other benchmark file, catching broad exceptions here ensures the benchmark suite continues running through all configurations. The error is logged with context.
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)
83-105: LGTM on compute_global_scale for add+rmsnorm.

The function correctly computes global_scale based on the `rmsnorm(x + residual, weight)` output, which is the appropriate reference for the add+rmsnorm fusion.
466-575: Excellent test coverage for fused vs separate comparison.

The `TestFusedVsSeparateFP4Quantize` class thoroughly validates:
- FP4 packed output byte-level matching (>95%)
- Block scale factor matching (>95%)
- Dequantized value closeness with tiered tolerance
- Both NVFP4 and MXFP4 formats
This ensures the fused kernel produces results consistent with the separate implementation.
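For reference when reading these byte-level comparisons, FP4 E2M1 packs two 4-bit codes per uint8. A decode sketch (the low-nibble-first packing order is an assumption for illustration, not taken from the kernel source):

```python
# Signed FP4 E2M1 values by 4-bit code: bit 3 is the sign, and codes 0-7
# map to the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_fp4_byte(byte: int):
    """Unpack one uint8 into two E2M1 floats (low nibble first -- assumed)."""
    def decode_nibble(n: int) -> float:
        mag = E2M1_LUT[n & 0x7]
        return -mag if n & 0x8 else mag
    return decode_nibble(byte & 0xF), decode_nibble(byte >> 4)

print(decode_fp4_byte(0x52))  # (1.0, 3.0): code 2 -> 1.0, code 5 -> 3.0
print(decode_fp4_byte(0xF0))  # (0.0, -6.0): code 0xF -> -6.0
```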
1034-1078: Good test coverage for auto-allocation.

The `TestAutoAllocation` class mirrors the RMSNorm test file structure, providing comprehensive coverage for the add+rmsnorm fusion kernel's auto-allocation feature.
| """Device kernel with cluster sync and Half2 SIMD. | ||
|
|
||
| mGlobalScale contains the global scale value. The kernel reads it and | ||
| computes 1/global_scale, which is multiplied with rstd to apply: | ||
| y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale | ||
| """ |
Docstring inaccurately describes global_scale behavior.
The docstring states that the kernel "computes 1/global_scale, which is multiplied with rstd" and produces "y = rmsnorm(h, w) / global_scale". However, the implementation does not multiply rstd by 1/global_scale. Instead:

- RMSNorm is computed normally: `y = h * rstd * w` (lines 1335-1350)
- `global_scale` is incorporated into the stored block scale: `scale_float = global_scale_val * max_abs / FP4_MAX` (line 1371)
- Quantization uses `inv_scale = FP4_MAX / max_abs`, which cancels out `global_scale` (lines 1378-1381)
- Net effect: dequantized output = `rmsnorm(h, w) * global_scale` (multiplies, not divides)

The docstring should clarify that global_scale multiplies the dequantized output by being incorporated into the stored block scale, rather than claiming it divides the RMSNorm output.
Proposed docstring correction

```diff
-    """Device kernel with cluster sync and Half2 SIMD.
-
-    mGlobalScale contains the global scale value. The kernel reads it and
-    computes 1/global_scale, which is multiplied with rstd to apply:
-    y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
-    """
+    """Device kernel with cluster sync and Half2 SIMD.
+
+    mGlobalScale contains the global scale value that is incorporated into
+    the block scale. RMSNorm is computed normally (y = h * rstd * w), but
+    the stored block scale is multiplied by global_scale. This results in
+    dequantized outputs being scaled: dequant_output = rmsnorm(h, w) * global_scale.
+    """
```
🤖 Prompt for AI Agents
In flashinfer/cute_dsl/add_rmsnorm_fp4quant.py around lines 1109-1114, the
docstring incorrectly states the kernel computes 1/global_scale and divides the
RMSNorm output; instead, the implementation incorporates global_scale into the
stored block scale so the dequantized output is multiplied by global_scale.
Update the docstring to state that RMSNorm is computed normally (y = h * rstd *
w), that global_scale is factored into the stored block scale (scale_float =
global_scale_val * max_abs / FP4_MAX) and thus the dequantized result is
multiplied by global_scale, and remove the incorrect “1/global_scale” language;
reference the relevant implementation lines (≈1335-1350, 1371, 1378-1381) for
clarity.
```python
    global_scale : torch.Tensor, optional
        Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
        If provided, the RMSNorm output is divided by this value before quantization:
        ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
        This is used for NVFP4 format where a pre-computed global scale lifts
        per-block scales into optimal dynamic range.
        If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
    eps : float
```
Docstring incorrectly describes global_scale effect on quantization.
The docstring claims that "the RMSNorm output is divided by this value before quantization: y = rmsnorm(h, w) / global_scale". This is inaccurate. The implementation:

- Computes RMSNorm normally, without applying global_scale to intermediate values
- Incorporates `global_scale` into the stored block scale
- Results in dequantized output = `rmsnorm(h, w) * global_scale` (multiply, not divide)

The parameter description should clarify that global_scale adjusts the magnitude of dequantized outputs by being baked into the block scale, and that larger values produce larger outputs (not smaller).
Proposed docstring correction

```diff
     global_scale : torch.Tensor, optional
         Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
-        If provided, the RMSNorm output is divided by this value before quantization:
-        ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
-        This is used for NVFP4 format where a pre-computed global scale lifts
-        per-block scales into optimal dynamic range.
+        If provided, this value is incorporated into the per-block scales for E4M3 format.
+        The effect is to scale the dequantized output: ``dequant = rmsnorm(h, w) * global_scale``
+        where ``h = input + residual``. This adjusts the magnitude of outputs without affecting
+        quantization granularity. Only used for E4M3 format; ignored for UE8M0 (MXFP4).
         If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
```

🤖 Prompt for AI Agents
In flashinfer/cute_dsl/add_rmsnorm_fp4quant.py around lines 2328 to 2335, the
docstring incorrectly states that the RMSNorm output is divided by global_scale
before quantization; instead, global_scale is incorporated into the stored block
scale so the dequantized output is multiplied by global_scale (dequantized =
rmsnorm(h, w) * global_scale). Update the parameter description to explain that
providing global_scale bakes that factor into the block scale and increases the
magnitude of dequantized outputs (larger global_scale -> larger outputs), and
clarify that None means no global scaling (equivalent to global_scale=1.0).
Code Review
This pull request is a high-quality contribution that enhances the rmsnorm_fp4quant and add_rmsnorm_fp4quant kernels with optional output allocation and support for a global scale factor. The implementation is robust, with the new features correctly integrated into the CuTe-DSL kernels, Python API, benchmarks, and tests.
The API changes make the functions more flexible by allowing automatic allocation of output tensors. The global scale support for NVFP4 quantization is correctly implemented by incorporating the scale into the block scale computation within the kernel, which is crucial for dynamic range management.
The benchmarks have been significantly improved by refactoring the unfused baseline to measure the entire pipeline, providing a more realistic performance comparison. The tests are exceptionally thorough, with new test classes validating auto-allocation, global scale correctness, and the equivalence between fused and separate execution paths. The introduction of a tiered tolerance assertion function is an excellent addition for robustly testing low-precision quantized outputs.
Overall, the changes are well-executed, well-tested, and improve both the functionality and usability of the fusion kernels. I have no specific comments on the code changes.
[SUCCESS] Pipeline #40672783: 12/20 passed
Force-pushed from 1ba2ead to c660708 (compare)
Actionable comments posted: 2
♻️ Duplicate comments (3)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (3)
1096-1114: Kernel docstring still describes incorrect 1/global_scale behavior.

Lines 1111-1114 claim the kernel "computes 1/global_scale, which is multiplied with rstd" and that `y = rmsnorm(h, w) / global_scale`. The implementation doesn't modify `rstd` with `1/global_scale`; it uses `mGlobalScale` only in the E4M3 scale/inv_scale math, leaving RMSNorm unchanged and expecting dequantization to divide by `global_scale`.

Align this docstring with the implemented behavior: RMSNorm is computed normally, `global_scale` is folded into block scales for NVFP4/E4M3, and consumers should divide by `global_scale` when dequantizing.
2281-2335: Clarify public API global_scale semantics (baked into block scales, undone at dequant).

The `global_scale` parameter docstring still says "RMSNorm output is divided by this value before quantization: `y = rmsnorm(h, w) / global_scale`", which doesn't reflect the implementation:

- The fused kernel computes RMSNorm normally.
- For NVFP4/E4M3, `global_scale` is folded into the stored block scales and `inv_scale`.
- Tests dequantize by multiplying FP4 bytes with the block scales and then dividing by `global_scale` to recover the RMSNorm output.

To avoid confusion, please update this section to something along the lines of:

- global_scale is only used for E4M3 / NVFP4; it is ignored for UE8M0 (MXFP4).
- It is incorporated into the per-block scales (`scale ≈ global_scale * max_abs / FP4_MAX`).
- Downstream dequantization should divide by `global_scale` to reverse this factor; with `global_scale=None` (or 1.0), behavior matches prior semantics.
1016-1027: Update call docstring to match actual global_scale usage.

The kernel no longer multiplies `rstd` by `1/global_scale` or computes `y = rmsnorm(h) / global_scale`; instead, `global_scale` is read into `mGlobalScale` and used only when forming E4M3 block scales and `inv_scale`. Conceptually, global_scale is baked into per-block scales, and downstream dequantization is expected to divide by `global_scale` to recover the RMSNorm output.

Please reword this docstring to describe:

- RMSNorm computed normally (`y = h * rstd * w`)
- `global_scale` incorporated into E4M3 block scales (`scale ≈ global_scale * max_abs / FP4_MAX`)
- Dequantizers should divide by `global_scale` to undo this factor
🧹 Nitpick comments (9)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
1939-1992: Consider adding shape validation for user-provided output tensors.

The output allocation logic correctly handles all cases (2D/3D, swizzled/non-swizzled). However, when users provide pre-allocated `y_fp4` or `block_scale` tensors, there's no validation to ensure they have the correct shapes and dtypes.

For example:

- If `y_fp4` has shape `(batch_size, hidden_size)` instead of `(batch_size, hidden_size // 2)`, the kernel will write to incorrect memory locations
- If `block_scale` is `torch.uint8` but `scale_format="e4m3"`, a type mismatch will occur

🔎 Proposed validation logic

Add validation before kernel launch (around line 1939):

```diff
+ # Validate user-provided output tensors
+ if y_fp4 is not None:
+     expected_shape = (batch_size, hidden_size // 2) if not is_3d else (B, S, hidden_size // 2)
+     if y_fp4.shape != expected_shape:
+         raise ValueError(f"y_fp4 shape mismatch: expected {expected_shape}, got {y_fp4.shape}")
+     if y_fp4.dtype != torch.uint8:
+         raise ValueError(f"y_fp4 dtype must be torch.uint8, got {y_fp4.dtype}")
+
+ if block_scale is not None:
+     scale_dtype = torch.uint8 if actual_scale_format == "ue8m0" else torch.float8_e4m3fn
+     if block_scale.dtype != scale_dtype:
+         raise ValueError(f"block_scale dtype must be {scale_dtype} for {actual_scale_format}, got {block_scale.dtype}")
+     # Add shape validation based on layout mode
+
  # Allocate output tensors if not provided
  if y_fp4 is None:
```
1252-1254: Clarify `global_scale` behavior for MXFP4 (UE8M0) in docstring.

The implementation correctly incorporates `global_scale` for NVFP4 format (E4M3 scales) and intentionally ignores it for MXFP4 format (UE8M0 scales). However, the docstring is ambiguous about this limitation.

The `global_scale` parameter documentation states: "This is used for NVFP4 format where a pre-computed global scale lifts per-block scales into optimal dynamic range." This implies the limitation but doesn't explicitly state that `global_scale` is ignored for MXFP4 when `block_size=32` or `scale_format="ue8m0"`.

Update the `global_scale` parameter documentation to explicitly note: "Note: This parameter is only used for NVFP4 format (block_size=16, E4M3 scales). For MXFP4 format (block_size=32, UE8M0 scales), global_scale is ignored."

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2368-2465: Auto-allocation and 2D/3D/global_scale handling in Python API looks sound.

The reshaping to 2D, conditional allocation for 2D/3D and swizzled/unswizzled layouts, and default `global_scale = torch.ones(1, ...)` all line up with the kernel's expectations and the new tests (auto-allocation, swizzled, NVFP4/MXFP4, global_scale consistency).

The only micro-optimization you might consider is avoiding redundant `.contiguous()` calls on already-contiguous views, but that's optional and not performance-critical.

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (1)
256-365: Narrow broad exception handling in benchmark loop.

The `try`/`except Exception as e:` blocks around fused and unfused timing will also swallow unexpected programming errors or misconfigurations, which can hide real issues.

Consider catching more specific exceptions (e.g., `RuntimeError`, `torch.cuda.OutOfMemoryError`) and either re-raising others or at least logging them distinctly.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (2)
42-147: Helper semantics are correct but duplicated across test modules.

`dequantize_fp4_output`, `compute_global_scale`, and `assert_close_with_tiered_tolerance` correctly model:

- UE8M0 vs E4M3 scale decoding,
- global_scale being baked into block scales and then undone by dividing at dequant,
- and a two-tier tolerance regime appropriate for FP4 noise.

The same patterns appear in the RMSNorm FP4 test file; consider factoring these into a shared helper module (e.g., `tests/test_helpers/fp4_quantization.py`) to avoid divergence over time.
666-727: Tight global_scale consistency check is good; mark unused y_fp4 variables.

The test correctly asserts that `block_scale` with `global_scale` is ~`global_scale` times larger than without, in line with the formula `scale = global_scale * max_abs / FP4_MAX`.

Since `y_fp4_gs` and `y_fp4_no_gs` are unused, consider renaming them to `_y_fp4_gs` and `_y_fp4_no_gs` (or unpacking only the second element) to satisfy linters without changing behavior.

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (2)
327-355: Consider narrowing broad exception handling in benchmark loop.

As in the add+rmsnorm benchmark, the bare `except Exception` blocks can hide unexpected programming errors.

It would be safer to catch expected runtime issues explicitly (e.g., CUDA OOM) and optionally re-raise others, instead of treating all failures as a generic "FUSED/UNFUSED ERROR".

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)
90-152: Global-scale computation and tiered tolerance helper are well-designed (but duplicated).
`compute_global_scale` and `assert_close_with_tiered_tolerance` mirror the helpers in the add+RMS tests and are appropriate for:

- Choosing a `global_scale` that fits RMSNorm outputs into FP4 dynamic range.
- Evaluating FP4 dequant results with a two-tier tolerance.
As noted in the add+RMS test file, consider extracting these shared helpers into a common test utility to reduce duplication.
711-768: Global-scale value consistency test is correct; mark unused y_fp4 variables.

The test correctly checks that `block_scale` with global_scale is ~`global_scale` times `block_scale` without, in accordance with the formula used inside the kernel.

Since `y_fp4_gs` and `y_fp4_no_gs` are not used, consider renaming them to `_y_fp4_gs` and `_y_fp4_no_gs` (or unpacking only the second return) to satisfy linters.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
- benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`tests/**/*.py`:

- Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
- For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
- Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes; `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/**/*.py`:

- Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
- Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

- flashinfer/cute_dsl/rmsnorm_fp4quant.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (4)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)

- include/flashinfer/trtllm/fused_moe/runner.h (1)
  - hidden_size (265-265)
- flashinfer/norm.py (1)
  - rmsnorm (33-68)
- flashinfer/testing/utils.py (1)
  - bench_gpu_time (1484-1631)

flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)

- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
  - kernel (1096-2173)
  - get_cute_pointers (2193-2239)

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

- flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
  - kernel (1091-1728)
  - get_cute_pointers (1751-1792)

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3)

- benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
  - unfused_operation (143-155)
  - sanity_check_outputs (170-253)
  - compute_bandwidth_gb_s (37-73)
- flashinfer/norm.py (1)
  - rmsnorm (33-68)
- flashinfer/testing/utils.py (1)
  - bench_gpu_time (1484-1631)
🪛 Ruff (0.14.10)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
74-74: Avoid specifying long messages outside the exception class
(TRY003)
143-146: Avoid specifying long messages outside the exception class
(TRY003)
690-690: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
703-703: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
343-343: Do not catch blind exception: Exception
(BLE001)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
78-78: Avoid specifying long messages outside the exception class
(TRY003)
149-152: Avoid specifying long messages outside the exception class
(TRY003)
733-733: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
745-745: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
337-337: Do not catch blind exception: Exception
(BLE001)
350-350: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (21)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
1010-1010: LGTM: Global scale parameter integration. The addition of `global_scale_ptr` to the kernel signature and the creation of the `mGlobalScale` tensor are correctly implemented. The parameter is properly threaded through the host and device functions. Also applies to: 1063-1067, 1080-1080, 1097-1097
1768-1791: LGTM: Global scale pointer creation. The pointer creation logic correctly handles both compilation (dummy pointer) and runtime (actual tensor) paths for `global_scale`. The alignment of 4 bytes for Float32 is appropriate.
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2218-2239: Global scale pointer wiring looks consistent with kernel expectations. Adding `global_scale` as a Float32 GMEM pointer in `get_cute_pointers` and threading it into `cute.compile` and the compiled closure matches the new kernel signature. The dummy pointer path and `assumed_align=4` are appropriate for the scalar.
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (4)
76-123: Fused benchmark wiring with global_scale is correct. Threading `global_scale` through to `add_rmsnorm_fp4quant` while keeping block_scale dtype consistent with `block_size` (E4M3 vs UE8M0) matches the new API and tests. Median of `bench_gpu_time` remains the right aggregation.
125-168: Unfused path correctly mirrors fused math and global_scale use. The unfused `torch.add` + `rmsnorm` + `fp4_quantize` sequence, including `global_scale` for NVFP4 and `sf_use_ue8m0=(block_size == 32)`, is aligned with how the fused kernel is exercised in tests. This looks like a valid baseline for speedup comparison.
170-253: Sanity-check path is consistent with fused/unfused semantics under global_scale. Using the same `global_scale` for both fused and separate paths and comparing FP4 bytes with a relaxed percentage threshold is a good practical validation of correctness given FP4 noise and different operation ordering.
371-381: Geomean speedup reporting against unfused baseline looks good. Collecting non-`None` speedups and reporting the geometric mean vs "unfused add + rmsnorm + fp4_quantize" matches the new fused-vs-unfused framing.
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)
219-231: Tiered tolerance usage in value-level comparisons is appropriate. Using `assert_close_with_tiered_tolerance` with tighter (rtol=0.3, atol=0.5) and looser (rtol=0.5, atol=2.0) thresholds for FP4 dequant results is a reasonable balance between strictness and the coarseness of 4-bit quantization.
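A two-tier check of this shape can be sketched in plain Python; the helper name, thresholds, and the 90% tight fraction here are illustrative, not necessarily what the repo's `assert_close_with_tiered_tolerance` actually uses:

```python
def assert_close_tiered(actual, expected, frac_tight=0.9,
                        tight=(0.3, 0.5), loose=(0.5, 2.0)):
    """Two-tier closeness check for low-precision (e.g. FP4) outputs.

    Every element must satisfy the loose (rtol, atol) bound, and at
    least frac_tight of the elements must satisfy the tight bound.
    """
    n_tight = 0
    for a, e in zip(actual, expected):
        diff = abs(a - e)
        # Tier 2: hard cap on any single element's error.
        assert diff <= loose[1] + loose[0] * abs(e), "loose bound violated"
        # Tier 1: count elements inside the strict bound.
        if diff <= tight[1] + tight[0] * abs(e):
            n_tight += 1
    assert n_tight >= frac_tight * len(actual), "too few tight matches"
```

The two tiers capture the same intent as the review comment: most dequantized FP4 values should be close, while a looser bound still catches gross mismatches.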
466-575: Fused vs separate NVFP4/MXFP4 comparisons robustly validate global_scale handling. The new `TestFusedVsSeparateFP4Quantize` tests:
- Compare packed FP4 bytes and block scales between fused and separate paths.
- Use dequantization with optional `global_scale` and tiered tolerances.
This is exactly what's needed to ensure the fused kernel applies `global_scale` identically to `fp4_quantize` for both NVFP4 (E4M3) and MXFP4 (UE8M0).
1034-1260: Auto-allocation tests thoroughly exercise new API shapes and layouts. The `TestAutoAllocation` class verifies:
- 2D/3D NVFP4, MXFP4, and swizzled layouts.
- Correct shapes and dtypes of auto-allocated `y_fp4` and `block_scale`.
- Equality between preallocated and auto-allocated results.
These provide strong coverage for the new `(y_fp4, block_scale)` return semantics in `add_rmsnorm_fp4quant`.
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (5)
87-133: CuTe-DSL fused benchmark correctly passes global_scale to rmsnorm_fp4quant. The extended `bench_cute_dsl` signature and lambda correctly propagate `global_scale` to the fused RMSNorm+FP4 kernel, while choosing block_scale dtype and `scale_format` based on `block_size`. This aligns with the updated API and tests.
135-177: Unfused RMSNorm + fp4_quantize path matches fused semantics. The `unfused_operation` (RMSNorm followed by `fp4_quantize`) and its use of `global_scale` for NVFP4, `sf_use_ue8m0` for MXFP4, and `bench_gpu_time` over the combined op provide a solid baseline for fused speedups.
179-260: Sanity-check compares fused vs separate outputs under shared global_scale. Running both `rmsnorm_fp4quant` and separate `rmsnorm + fp4_quantize` with the same `global_scale` and checking FP4 match percentage is a practical validation of the fused kernel's correctness given FP4's low precision.
278-287: Global_scale construction in benchmark is reasonable and consistent. Using `FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / 3.0` as a fixed calibration-like `global_scale` value is consistent with the test helper's formulation and sufficient for benchmarking.
376-383: Geomean speedup vs unfused RMSNorm+fp4_quantize is reported correctly. Collecting finite speedups and printing the geometric mean relative to the unfused path reflects the intended performance comparison after the refactor.
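The aggregation described here can be sketched as follows (the function name is illustrative; the benchmark's actual helper may differ):

```python
import math

def geomean_speedup(speedups):
    """Geometric mean of fused-vs-unfused speedup ratios.

    None entries (failed configs) and non-finite values are skipped,
    mirroring how a benchmark would aggregate per-shape results.
    """
    vals = [s for s in speedups if s is not None and math.isfinite(s) and s > 0]
    if not vals:
        return float("nan")
    return math.exp(sum(math.log(v) for v in vals) / len(vals))
```

The geometric mean is the right choice for ratios: a 2x and an 8x speedup average to 4x, not 5x, so no single shape dominates the summary.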
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (6)
42-88: Dequantization and global_scale handling align with fused kernel math. `dequantize_fp4_output` correctly:
- Handles 2D and 3D FP4 layouts,
- Decodes E4M3 vs UE8M0 scales,
- Divides by `global_scale` when provided, consistent with block scales that include a global_scale factor.
This matches the quantization math used in `rmsnorm_fp4quant` and the NVFP4 tests.
199-258: NVFP4 2D/3D tests properly exercise global_scale integration. The 2D and 3D NVFP4 tests:
- Compute `global_scale` from reference RMSNorm outputs,
- Run `rmsnorm_fp4quant` with that scale,
- Dequantize with `dequantize_fp4_output(..., global_scale)` and compare using tiered tolerances.
This gives strong end-to-end validation of the new global_scale behavior.
444-510: Fused vs separate RMSNorm comparison correctly uses dequant + tiered tolerances. Comparing dequantized fused outputs (with global_scale) to separate `rmsnorm` results using `assert_close_with_tiered_tolerance` is the right level of check, avoiding dependence on internal scale computation details while still validating functional correctness.
512-621: Fused vs separate fp4_quantize tests robustly validate global_scale consistency. The NVFP4 and MXFP4 tests under `TestFusedVsSeparateFP4Quantize`:
- Compare packed FP4 bytes and block-scale tensors between fused and standalone `fp4_quantize`.
- Also compare dequantized outputs with tiered tolerances.
These should catch any divergence in how global_scale is applied across paths.
781-837: Large hidden NVFP4 test properly exercises cluster sync path with global_scale. Sampling a subset of rows, computing `global_scale`, and validating dequantized outputs against reference RMSNorm helps ensure the cluster-synchronized path behaves correctly under the new scaling scheme.
1075-1321: Auto-allocation tests for RMSNorm-only path mirror add+RMS semantics and look solid. `TestAutoAllocation` covers:
- 2D/3D NVFP4 with global_scale,
- MXFP4 without global_scale,
- Swizzled layouts,
- Parity between preallocated and auto-allocated outputs.
These tests effectively validate the new "optional outputs returning (y_fp4, block_scale)" behavior in `rmsnorm_fp4quant`.
y_fp4: torch.Tensor | None = None,
block_scale: torch.Tensor | None = None,
global_scale: torch.Tensor | None = None,
🧩 Analysis chain
🏁 Script executed:
# Get the function signature and context around the mentioned lines
sed -n '1835,1850p' flashinfer/cute_dsl/rmsnorm_fp4quant.py
Repository: flashinfer-ai/flashinfer
Length of output: 609
🏁 Script executed:
# Check the return statements and key implementation lines
sed -n '1902,1915p' flashinfer/cute_dsl/rmsnorm_fp4quant.py
Repository: flashinfer-ai/flashinfer
Length of output: 593
🏁 Script executed:
# Check line 2013 (return statement)
sed -n '2010,2020p' flashinfer/cute_dsl/rmsnorm_fp4quant.py
Repository: flashinfer-ai/flashinfer
Length of output: 208
🏁 Script executed:
# Search for existing calls to rmsnorm_fp4quant in the codebase
rg -nP --type=py 'rmsnorm_fp4quant\s*\(' -B1 -A3
Repository: flashinfer-ai/flashinfer
Length of output: 19556
🏁 Script executed:
# Check if function has decorators (like @functools.cache or @flashinfer_api)
sed -n '1830,1845p' flashinfer/cute_dsl/rmsnorm_fp4quant.py
Repository: flashinfer-ai/flashinfer
Length of output: 448
Add @functools.cache decorator for module-level caching per coding guidelines.
The function is decorated with @flashinfer_api as required, but lacks the @functools.cache decorator specified in the coding guidelines for flashinfer/**/*.py files. This should be added to implement module-level caching and avoid recompilation.
Additionally, the API signature has been significantly changed:
- `y_fp4` and `block_scale` are now optional (previously required)
- `global_scale` parameter added
- Function now returns `Tuple[torch.Tensor, torch.Tensor]` instead of `None`
These changes are intentional and documented in the docstring. Existing code passing pre-allocated tensors will continue to work (Python allows ignoring return values), but usage patterns differ from the previous in-place operation model.
🤖 Prompt for AI Agents
In flashinfer/cute_dsl/rmsnorm_fp4quant.py around lines 1839-1841, the function
is missing the required module-level caching decorator; add @functools.cache
immediately above the existing @flashinfer_api decorator to enable module-level
caching and prevent recompilation, and ensure functools.cache is imported at the
module level (add "import functools" if not present); do not change the function
signature or behavior beyond adding the decorator.
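A sketch of the suggested stacking, using a stand-in for the real `@flashinfer_api` decorator and hashable arguments (since `functools.cache` cannot hash tensor inputs, in practice the cache fits an internal compiled-kernel getter keyed on shapes/dtypes better than the tensor-taking public entry point):

```python
import functools

def flashinfer_api(fn):
    """Stand-in for the real decorator; it only wraps the call here."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

# cache sits on top so the wrapped call is what gets memoized;
# repeated lookups with the same keys then skip recompilation entirely.
@functools.cache
@flashinfer_api
def get_compiled_kernel(hidden_size: int, block_size: int):
    # Placeholder for the expensive cute.compile(...) step.
    return ("compiled", hidden_size, block_size)
```

Calling `get_compiled_kernel` twice with the same arguments returns the cached object without re-entering the function body.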
global_scale : torch.Tensor, optional
    Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
    If provided, the RMSNorm output is divided by this value before quantization:
    ``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where
    a pre-computed global scale lifts per-block scales into optimal dynamic range.
    If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
🛠️ Refactor suggestion | 🟠 Major
Clarify global_scale documentation for MXFP4 format.
The docstring states that global_scale is "used for NVFP4 format" but doesn't explicitly mention that it's ignored for MXFP4 (block_size=32 with UE8M0 scale format). This could lead to confusion when users provide global_scale with block_size=32.
Recommend adding a note in the documentation:
🔎 Proposed documentation update
global_scale : torch.Tensor, optional
Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
If provided, the RMSNorm output is divided by this value before quantization:
``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where
a pre-computed global scale lifts per-block scales into optimal dynamic range.
+ **Note**: This parameter is only applicable for NVFP4 (block_size=16 with E4M3
+ scale format). It is ignored for MXFP4 (block_size=32 with UE8M0 scale format).
If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
This relates to the earlier major issue about validating or supporting global_scale for UE8M0.
Also applies to: 1910-1916
/bot run |
Tuple[torch.Tensor, torch.Tensor]
    A tuple of ``(y_fp4, block_scale)``:

    - ``y_fp4``: Quantized FP4 values packed as uint8.
Can you use float4_e2m1fn_x2 instead for torch 2.8+?
Didn't realize torch.float4_e2m1fn_x2 was available; thanks for pointing this out. Changed the output format (and unit tests accordingly) in the latest commits
@cute.jit
def __call__(
    self,
    x_ptr: cute.Pointer,
With tvm-ffi enabled (https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html), we can pass cute.Tensor directly instead of cute.Pointer without overhead, I'll create a refactor PR later.
/bot run |
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
1003-1014: global_scale kernel docstring does not match implemented behavior. The device kernel docstring and `__call__` comment state that the kernel "computes 1/global_scale" and applies `y = rmsnorm(x, w) / global_scale`. In the E4M3 path, the implementation instead:
- Computes `scale_float = global_scale_val * max_abs / FP4_MAX`,
- Stores this (quantized) as the per-block scale, and
- Uses `inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) * global_scale_val` ≈ `FP4_MAX / max_abs`, which is independent of `global_scale_val`.
This means:
- The FP4 codes (`q`) are effectively the same as in the no-global-scale case.
- The stored block scales are multiplied by `global_scale`, so dequantization `q * block_scale` yields outputs ≈ `rmsnorm(x, w) * global_scale` (not divided). For UE8M0, `global_scale` is ignored entirely, as expected.
The docs should be updated to describe the actual behavior: `global_scale` is folded into the stored E4M3 block scales, leaving quantization codes unchanged while scaling dequantized outputs proportionally. The "1/global_scale" and "/ global_scale" language is misleading given the current math.
Also applies to: 1063-1068, 1090-1108, 1252-1255, 1387-1391, 1396-1401, 1635-1655
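The cancellation described in this comment can be checked with a small float model of the E4M3 path; this ignores the FP8 rounding of the stored scale, and the names mirror the kernel variables only for readability:

```python
FP4_MAX = 6.0  # largest magnitude representable in float4 e2m1

def e4m3_block_scale(max_abs, global_scale_val):
    """Float model of the kernel's E4M3 scale math, without FP8 rounding."""
    # The stored per-block scale carries the global factor ...
    scale_float = global_scale_val * max_abs / FP4_MAX
    # ... but the quantization multiplier divides it back out, so the
    # FP4 codes depend only on max_abs, not on global_scale_val.
    inv_scale = (1.0 / scale_float) * global_scale_val
    return scale_float, inv_scale
```

Evaluating with two different global scales shows identical `inv_scale` (so identical FP4 codes) while the stored scale, and hence the dequantized `q * block_scale`, grows with `global_scale`.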
1736-1792: Validate global_scale tensor device/dtype/shape in the Python API. The pointer wiring and auto-allocation logic (`y_fp4` and `block_scale` for 2D/3D, swizzled/unswizzled) look solid, and `_get_compiled_kernel` is correctly cached. However, `rmsnorm_fp4quant` assumes:
- `global_scale` is on the same CUDA device as `input`,
- Has shape `(1,)`, and
- Has dtype `torch.float32`,
but doesn't enforce any of these before passing its `data_ptr()` to the kernel as a `Float32` gmem pointer. If a caller accidentally passes a CPU tensor, different shape, or different dtype, this will yield undefined behavior at the CUDA level rather than a clear Python-side error. Consider adding cheap upfront validation, e.g. `assert global_scale.device == input.device`, `assert global_scale.dtype == torch.float32`, `assert global_scale.numel() == 1` (or corresponding `ValueError`s) before the kernel launch.
Also applies to: 1814-1832, 1939-1997, 2007-2017
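A minimal validation helper of the kind suggested could look like this sketch; it is duck-typed (anything exposing `.device`, `.dtype`, and `.numel()`) so it stays torch-free here, and the function name is illustrative:

```python
def validate_global_scale(global_scale, input_device):
    """Fail fast on the host instead of handing a bad pointer to CUDA.

    Mirrors what the kernel assumes: a 1-element float32 tensor on the
    same device as the input. A None value is allowed and treated as
    global_scale == 1.0.
    """
    if global_scale is None:
        return
    if str(global_scale.device) != str(input_device):
        raise ValueError("global_scale must be on the same device as input")
    if "float32" not in str(global_scale.dtype):
        raise ValueError("global_scale must have dtype torch.float32")
    if global_scale.numel() != 1:
        raise ValueError("global_scale must have exactly one element")
```

Raising a `ValueError` on the host turns a silent illegal device read into an actionable Python error at the call site.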
♻️ Duplicate comments (3)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1835-1846: Clarify global_scale API doc: effect and MXFP4 applicability. The `rmsnorm_fp4quant` docstring currently says:
- The RMSNorm output is "divided by" `global_scale` before quantization (`y = rmsnorm(x, w) / global_scale`), and
- Mentions NVFP4 usage but not that `global_scale` is ignored for MXFP4/UE8M0.
Given the kernel implementation:
- For E4M3 (NVFP4), `global_scale` is incorporated into the stored block scales, leading to dequantized outputs proportional to `global_scale` (while the FP4 codes stay effectively unchanged).
- For UE8M0 (MXFP4), `global_scale` is not used at all.
The parameter docs and "Returns" section should be adjusted to:
- Describe that `global_scale` is folded into E4M3 block scales and affects the magnitude of dequantized outputs, not that the RMSNorm output is divided by it pre-quantization.
- Explicitly state that `global_scale` is only applicable for NVFP4 (block_size=16, E4M3) and is ignored for MXFP4 (block_size=32, UE8M0).
This aligns the public API contract with the actual kernel math and avoids confusion for users calibrating global scales.
Also applies to: 1879-1883, 1902-1909
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
1008-1027: Kernel/global_scale docstring still contradicts actual quantization behavior. The `AddRMSNormFP4QuantKernel` kernel and host `__call__` docstrings describe:
- Computing `1/global_scale`, multiplying it into `rstd`, and
- Effectively applying `y = rmsnorm(h, w) / global_scale`.
As in the standalone RMSNorm kernel, the implementation instead:
- Incorporates `global_scale` into `scale_float = global_scale_val * max_abs * fp4_max_rcp` for E4M3 (NVFP4),
- Uses `inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) * global_scale_val`, which cancels `global_scale` so FP4 codes are unchanged relative to the no-global-scale case, and
- Leaves UE8M0 (MXFP4) ignoring `global_scale` entirely.
Net effect: dequantized outputs (`q * block_scale`) are scaled by `global_scale` for E4M3, not divided by it. Please update these kernel-level docstrings to reflect:
- That `global_scale` is folded into E4M3 block scales (and ignored for UE8M0),
- That it scales the dequantized result rather than altering `rstd` directly.
Also applies to: 1068-1072, 1095-1114, 1274-1277, 1369-1382, 1511-1563, 1752-1772, 2076-2096
2281-2335: add_rmsnorm_fp4quant: align global_scale API doc with behavior and validate inputs. The high-level API changes are generally good:
- Optional `y_fp4`/`block_scale` with correct auto-allocation for 2D/3D and swizzled/unswizzled layouts.
- Return of `(y_fp4, block_scale)` is consistent with the new fused API design.
- Use of `torch.float4_e2m1fn_x2` and appropriate scale dtypes matches the kernels.
Two issues to address:
1. Docstring semantics and MXFP4 applicability
The `global_scale` parameter doc currently claims:
- The RMSNorm output is divided by `global_scale` before quantization (`y = rmsnorm(h, w) / global_scale`), and
- Does not state that MXFP4 ignores `global_scale`.
In reality (E4M3/NVFP4), `global_scale` is baked into the block scales; FP4 codes are unchanged, and dequantized outputs scale with `global_scale`. For UE8M0/MXFP4, `global_scale` is not used. The parameter description should be updated accordingly, and explicitly note that `global_scale` is only meaningful for NVFP4/E4M3 and ignored for MXFP4/UE8M0.
2. Runtime validation of global_scale tensor
As with `rmsnorm_fp4quant`, the function assumes `global_scale` is a 1-element `torch.float32` tensor on the same CUDA device as `input`, but does not check:
- Device equality,
- Dtype (`torch.float32`), or
- Shape/numel (1).
Passing a CPU tensor or wrong dtype would result in the kernel reading an invalid device pointer. Adding simple validation (or coercing to the correct device/dtype with a small copy) before invoking `tensor_api` would make this API much safer.
Also applies to: 2353-2367, 2368-2469
🧹 Nitpick comments (10)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (2)
127-170: Unfused benchmark behavior vs MXFP4/global_scale contract. The combined unfused path (add → RMSNorm → `fp4_quantize`) is structured correctly and includes `global_scale` for NVFP4. However, the helper and PR description state that `global_scale` should not be used for MXFP4, while this function would still forward a non-None `global_scale` if called with `block_size=32`. Consider explicitly gating this:
- Only pass `global_scale` to `fp4_quantize` when `block_size == 16`, or
- Force `global_scale=None` when `block_size == 32`.
This keeps the helper aligned with the documented MXFP4 behavior and avoids surprises if it's reused with `block_size=32`.
261-387: run_benchmark orchestration and reporting look good; consider refining exception handling. The benchmark wiring for:
- Fixed NVFP4 `block_size=16` and calibrated `global_scale`,
- Fused vs unfused timing and bandwidth computation, and
- Geomean speedup vs the unfused path
is coherent and matches the rmsnorm-only benchmark style.
The broad `except Exception as e`/`except Exception` blocks are acceptable for a CLI benchmark, but they do trip Ruff's BLE001 and can hide unexpected errors. If you want to align with the linter while keeping robustness, consider:
- Catching a narrower set (e.g., `RuntimeError`), or
- At least logging the exception type/message more prominently so unexpected failures are obvious.
137-179: Unfused RMSNorm+FP4 helper: clarify MXFP4/global_scale behavior
bench_separate_flashinfercorrectly sequences RMSNorm thenfp4_quantizeand times the combined operation. The doc comment says that for MXFP4 (block_size=32)global_scaleis not used, but the function still forwards theglobal_scaleargument tofp4_quantizeunconditionally.To keep behavior and docs in sync, consider:
- Only passing
global_scalewhenblock_size == 16(NVFP4), or- Passing
global_scale=Nonein the MXFP4 branch.This also matches the PR’s guidance that
global_scaleshould not be provided for MXFP4.
280-391: Benchmark harness updates and speedup reporting are coherent; optional refinement to exception handling. The changes to:
- Fix `block_size=16` (NVFP4),
- Introduce a calibrated `global_scale` for benchmarking,
- Report fused time, bandwidth, unfused time, and speedup, and
- Compute geomean speedup vs the unfused path
are internally consistent and mirror the Add+RMSNorm benchmark.
As in the other benchmark file, the bare `except Exception` handlers flagged by Ruff are acceptable for a benchmarking script but can obscure unexpected failures. Narrowing the exception type or improving the logged diagnostics would be a low-cost improvement.
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (4)
42-83: Helper functions for dequantization, global_scale, and tiered tolerance look correct. `dequantize_fp4_output` correctly:
- Interprets `torch.float4_e2m1fn_x2` as `uint8` for `cast_from_fp4`,
- Applies per-block scales for both E4M3 (`float8_e4m3fn`) and UE8M0 (`uint8` via `2^(ue8m0-127)`), and
- Optionally divides by `global_scale` to conceptually undo the fused scaling.
`compute_global_scale` matches the benchmark-style formula and uses a reference LLaMA RMSNorm; for test-only use this is fine, though guarding against `tensor_amax == 0` would make it more robust.
`assert_close_with_tiered_tolerance` is a good fit for low-precision FP4 comparisons, capturing both "most values tight" and "all values bounded" constraints. Overall, these helpers provide a solid foundation for the new global_scale-aware tests.
Also applies to: 85-108, 110-148
flashinfer.utilshelpers such asget_compute_capability/is_sm100a_supportedto keep skip logic consistent across the suite.Consider refactoring
requires_blackwell()(and any direct CC checks) to delegate to the shared utilities instead of duplicating capability logic here.Also applies to: 156-168
404-469: Fused vs separate FP4Quant tests are well-designed; minor cleanup around global_scale usage. The new `TestFusedVsSeparateFP4Quantize` tests:
- Compare fused Add+RMSNorm+FP4Quant against `add + RMSNorm + fp4_quantize` for both NVFP4 (block_size=16, E4M3) and MXFP4 (block_size=32, UE8M0).
- Check packed FP4 bytes (`view(torch.uint8)`), block scale factors, and dequantized values via `dequantize_fp4_output` and the tiered tolerance helper.
This is a strong end-to-end validation that the fused kernels match the standalone `fp4_quantize` implementation, including global_scale behavior. One small point: in the MXFP4 test you pass `global_scale_val = torch.tensor(1.0, ...)` positionally into `fp4_quantize` even though MXFP4 conceptually doesn't use `global_scale`. Keeping this at `1.0` is harmless, but if the underlying API ever tightens its contract for UE8M0, it may be safer to pass `global_scale=None` explicitly in that branch.
Also applies to: 472-587, 589-683
global_scale_val = torch.tensor(1.0, ...)positionally intofp4_quantizeeven though MXFP4 conceptually doesn’t useglobal_scale. Keeping this at1.0is harmless, but if the underlying API ever tightens its contract for UE8M0, it may be safer to passglobal_scale=Noneexplicitly in that branch.Also applies to: 472-587, 589-683
685-745: Unused y_fp4_gs / y_fp4_no_gs in global_scale consistency testIn
test_global_scale_value_consistency, the unpacked variables:
y_fp4_gs(line 708),y_fp4_no_gs(line 721),are never used; only the corresponding block scales are consumed.
To satisfy the linter and clarify intent, you can either:
- Prefix them with an underscore (
_y_fp4_gs,_y_fp4_no_gs), or- Assign to
_if you don’t plan to use the outputs.tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)
25-28: Consider using flashinfer.utils functions for GPU capability checks.The custom
get_cc()function works but doesn't follow the coding guidelines. Per the guidelines and learnings, test implementations should useflashinfer.utilsfunctions likeget_compute_capability(),is_sm90a_supported(), oris_sm100a_supported()to skip tests on unsupported GPU architectures.🔎 Suggested refactor
+from flashinfer.utils import get_compute_capability + -def get_cc(): - """Get CUDA compute capability.""" - major, minor = torch.cuda.get_device_capability() - return major * 10 + minorThen update usages:
def requires_hopper_or_later(): """Check if running on Hopper (SM90+) or later GPU.""" - return get_cc() >= 90 + return get_compute_capability() >= 90 def requires_blackwell(): """Check if running on Blackwell GPU.""" - return get_cc() >= 100 + return get_compute_capability() >= 100Based on coding guidelines: Test implementations should use
flashinfer.utilsfunctions for GPU capability checks.
751-751: Use underscore prefix for intentionally unused variables.Lines 751 and 763 unpack return values but only use the
block_scalevariables. Per Python convention, use underscore prefix for intentionally unused variables to improve clarity and silence linter warnings.🔎 Proposed fix
- y_fp4_gs, block_scale_gs = rmsnorm_fp4quant( + _y_fp4_gs, block_scale_gs = rmsnorm_fp4quant( x, weight, global_scale=global_scale, eps=eps, block_size=block_size, is_sf_swizzled_layout=False, ) - y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant( + _y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant( x, weight, global_scale=global_scale_one, eps=eps, block_size=block_size, is_sf_swizzled_layout=False, )Also applies to: 763-763
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
- benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/**/*.py`: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`tests/**/*.py`: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧠 Learnings (4)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Applied to files:
flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Applied to files:
flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
- tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧬 Code graph analysis (2)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
- benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3): unfused_operation (156-166), sanity_check_outputs (181-265), run_benchmark (268-392)
- flashinfer/norm.py (1): rmsnorm (33-68)
- flashinfer/testing/utils.py (1): bench_gpu_time (1484-1631)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
- flashinfer/cute_dsl/rmsnorm_fp4quant.py (6): kernel (1091-1728), fmin_f32 (208-220), cvt_f32_to_e4m3 (462-482), fp8_e4m3_to_f32_and_rcp (486-518), ue8m0_to_output_scale (574-606), get_cute_pointers (1751-1792)
- flashinfer/cute_dsl/utils.py (1): make_ptr (175-223)
🪛 Ruff (0.14.10)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
348-348: Do not catch blind exception: Exception
(BLE001)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
76-76: Avoid specifying long messages outside the exception class
(TRY003)
145-148: Avoid specifying long messages outside the exception class
(TRY003)
708-708: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
721-721: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
342-342: Do not catch blind exception: Exception
(BLE001)
355-355: Do not catch blind exception: Exception
(BLE001)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
80-80: Avoid specifying long messages outside the exception class
(TRY003)
151-154: Avoid specifying long messages outside the exception class
(TRY003)
751-751: Unpacked variable y_fp4_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
763-763: Unpacked variable y_fp4_no_gs is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (14)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (2)
76-89: Fused CuTe-DSL benchmark: global_scale threading and new FP4 dtype look correctThe added
global_scaleparameter is cleanly threaded intoadd_rmsnorm_fp4quant, and switchingy_fp4totorch.float4_e2m1fn_x2matches the fused kernel’s new output dtype. The allocation shapes and use ofbench_gpu_timeremain consistent with the bandwidth model (1 byte per packed FP4 pair).Also applies to: 105-122
172-257: Sanity check for fused vs separate path with global_scale is well-structured

The sanity check correctly:

- Uses the new `torch.float4_e2m1fn_x2` dtype for FP4 outputs.
- Propagates `global_scale` through both fused and separate paths.
- Compares packed FP4 bytes via `.view(torch.uint8)` with a reasonable ≥70% match threshold.

This should be a solid guard against regressions in the fused kernel's global_scale handling.
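The byte-level comparison behind that threshold can be modeled in plain Python (a sketch of the comparison logic only; the benchmark operates on `y_fp4.view(torch.uint8)` tensors, not byte lists):

```python
def fp4_match_rate(fused_bytes, separate_bytes):
    """Fraction of packed FP4 bytes that agree between two pipelines.

    Each byte holds two E2M1 values, so one mismatched nibble counts
    the whole byte as a miss; a >= 0.70 threshold tolerates rounding
    flips near block-scale boundaries.
    """
    assert len(fused_bytes) == len(separate_bytes)
    matches = sum(a == b for a, b in zip(fused_bytes, separate_bytes))
    return matches / len(fused_bytes)
```

For example, if 3 of 4 packed bytes agree the rate is 0.75, which clears the 70% bar.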
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (2)
87-97: Fused RMSNorm benchmark: global_scale and FP4 dtype integration LGTM

Adding `global_scale` to `bench_cute_dsl` and allocating `y_fp4` as `torch.float4_e2m1fn_x2` matches the fused kernel's API. The block-scale allocations and `scale_format` selection stay consistent with NVFP4 (E4M3) vs MXFP4 (UE8M0).

Also applies to: 114-125
181-265: Sanity check with global_scale and new FP4 dtype is sound

The updated `sanity_check_outputs`:

- Uses `torch.float4_e2m1fn_x2` and the updated fused API,
- Propagates `global_scale` through fused and separate paths, and
- Compares packed FP4 results via `.view(torch.uint8)` with a ≥70% match threshold.

This is a reasonable and robust check for fused-vs-separate behavior under global scaling.
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
2181-2215: Pointer wiring for global_scale looks correct

The additions to `_get_compiled_kernel` and `tensor_api` to handle a `global_scale` pointer are consistent with the tensor layout used in the kernel:

- Dummy and real pointer lists both include a final `cutlass.Float32` gmem pointer for the scalar global scale.
- `tensor_api` always receives a `global_scale` tensor and passes it into the compiled kernel invocation.

This matches the kernel's new parameter list without changing the call sites' responsibilities.
Also applies to: 2216-2239, 2259-2277
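How the scalar global scale folds into the per-block scale can be illustrated with a rough pure-Python model (a sketch under assumed constants and simplified arithmetic; the CuTe-DSL kernel additionally encodes the block scale to E4M3/UE8M0, which is omitted here):

```python
FP4_E2M1_MAX = 6.0  # largest finite E2M1 magnitude (assumed constant)

def block_quantize(block, global_scale=1.0):
    """Scale the block by global_scale, derive a per-block scale that
    maps the block amax onto the FP4 range, and rescale the values."""
    scaled = [v * global_scale for v in block]
    amax = max(abs(v) for v in scaled)
    block_scale = amax / FP4_E2M1_MAX if amax > 0.0 else 1.0
    quant = [v / block_scale for v in scaled]  # values now lie in [-6, 6]
    return quant, block_scale
```

Note that doubling `global_scale` doubles the stored block scale while leaving the quantized values unchanged, which is the proportionality the `test_global_scale_value_consistency` test checks.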
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (4)
197-215: Core 2D/3D/NVFP4/MXFP4 tests updated to new FP4 dtype and look solid

Across the main test classes:

- All `y_fp4` tensors now use `torch.float4_e2m1fn_x2` with the expected `(batch, hidden_size // 2)` or `(batch, seq_len, hidden_size // 2)` shapes.
- `block_scale` dtypes and shapes are consistent:
  - E4M3/NVFP4: `torch.float8_e4m3fn` with `(batch, ..., hidden_size // block_size)`,
  - UE8M0/MXFP4: `torch.uint8` with matching shapes.
- Dequantization checks use either plain `torch.testing.assert_close` or the tiered helper, with tolerances aligned to FP4 precision and MXFP4's extra quantization error.

These baseline correctness tests align well with the new kernel outputs and the rest of the PR.
Also applies to: 252-258, 315-323, 364-371, 775-795, 828-833
758-866: Large hidden-size tests with new FP4 dtype remain consistent

The large hidden-size NVFP4/MXFP4 tests:

- Correctly use `torch.float4_e2m1fn_x2` for `y_fp4`,
- Preserve expected `block_scale` dtypes (E4M3 vs UE8M0),
- Only sample a subset of rows for dequantization to keep runtime manageable.
Given the problem sizes, this strikes a good balance between coverage and test cost.
905-983: Swizzled vs unswizzled tests adapt cleanly to float4_e2m1fn_x2

The swizzled scale-factor tests:

- Allocate both reference and swizzled `y_fp4` as `torch.float4_e2m1fn_x2`,
- Compare FP4 outputs via `.view(torch.uint8)`, and
- Use `unswizzle_sf` to bring swizzled scales back to row-major for equality checks.

This is a thorough check that the new swizzled layout still matches the unswizzled baseline under the updated dtype.
Also applies to: 985-1053
1056-1284: Auto-allocation tests comprehensively cover NVFP4/MXFP4 and swizzled layouts

The `TestAutoAllocation` class:

- Verifies that omitting `y_fp4` and `block_scale` returns correctly-shaped and correctly-typed tensors for:
  - 2D/3D NVFP4,
  - MXFP4, and
  - NVFP4 with swizzled scale layout.
- Confirms numerical correctness against the LLaMA RMSNorm reference and equality vs preallocated outputs (bitwise via `.view(torch.uint8)`).

This is excellent coverage of the new allocation semantics and should catch most regressions in the Python wrapper behavior.
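For reference, the LLaMA-style RMSNorm the tests compare against can be sketched in plain Python (the `eps` value is an assumption; the repo's reference implementation operates on torch tensors):

```python
import math

def rmsnorm_ref(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * weight_i"""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```

The fused kernels apply this normalization (optionally preceded by the residual add) before FP4 quantization.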
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (5)
42-89: LGTM!

The global_scale integration in dequantization is correct. The function properly reverses the scaling applied during quantization by dividing the result by `global_scale.item()` when provided.
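That dequantization step amounts to multiplying by the per-block scale and then dividing the global scale back out; in plain-Python form (a sketch with scalar floats standing in for torch tensors):

```python
def dequantize_block(fp4_values, block_scale, global_scale=None):
    """Reverse FP4 quantization for one block: apply the decoded
    per-block scale, then undo the global scale if one was used
    (mirroring the `result / global_scale.item()` step in the test)."""
    out = [v * block_scale for v in fp4_values]
    if global_scale is not None:
        out = [v / global_scale for v in out]
    return out
```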
92-113: LGTM!

The global_scale computation correctly implements the formula to ensure the dynamic range fits within FP4. Constants and device placement are appropriate.
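The usual NVFP4 recipe for such a formula maps the tensor's absolute maximum onto the product of the E4M3 and E2M1 dynamic ranges; a hedged sketch (these constants follow the common convention, but the helper's exact form in this repo may differ):

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
FP4_E2M1_MAX = 6.0    # largest finite E2M1 value

def compute_global_scale(tensor_amax: float) -> float:
    """Choose global_scale so that E4M3 per-block scales and FP4
    values together can represent the full input range."""
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / tensor_amax
```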
116-154: LGTM!

The two-tiered tolerance check is well-designed for quantized outputs. The detailed error messages are valuable for debugging quantization mismatches, despite the static analysis warning about message length.
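The idea behind a two-tiered tolerance can be sketched as follows (the thresholds and the allowed-outlier fraction here are illustrative, not the test's actual values):

```python
def two_tier_close(actual, expected, atol_tight, atol_loose,
                   max_loose_frac=0.05):
    """Pass if every element is within atol_loose and at most
    max_loose_frac of elements exceed atol_tight; quantization can
    flip rounding for a few elements near block-scale boundaries."""
    loose_misses = 0
    for a, e in zip(actual, expected):
        err = abs(a - e)
        if err > atol_loose:
            return False  # hard failure: outside even the loose tier
        if err > atol_tight:
            loose_misses += 1
    return loose_misses / len(actual) <= max_loose_frac
```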
518-787: LGTM!

The new test classes provide excellent coverage:

- `TestFusedVsSeparateFP4Quantize` validates consistency between fused and separate quantization paths for both NVFP4 and MXFP4.
- `test_global_scale_value_consistency` verifies that global_scale correctly scales the block scales.
- Tests use appropriate tolerance checks and cover multiple parameter combinations.
1097-1346: LGTM!

The `TestAutoAllocation` class provides comprehensive coverage of the auto-allocation feature:
- Tests both 2D and 3D inputs with NVFP4 (including global_scale)
- Tests MXFP4 format with UE8M0 scales
- Tests swizzled layout auto-allocation
- Verifies auto-allocated results match pre-allocated results
- Proper shape, dtype, and value assertions throughout
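The shapes these tests expect the wrapper to auto-allocate can be modeled with plain tuples (a sketch of the contract only, not the torch allocation code; the default block size of 16 is the NVFP4 convention):

```python
def expected_output_shapes(batch, hidden_size, block_size=16):
    """Packed FP4 halves the last dim (two E2M1 values per byte);
    block scales divide it by the quantization block size."""
    y_fp4_shape = (batch, hidden_size // 2)
    block_scale_shape = (batch, hidden_size // block_size)
    return y_fp4_shape, block_scale_shape
```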
📌 Description
This PR enhances the `rmsnorm_fp4quant` and `add_rmsnorm_fp4quant` CuTe-DSL kernels with two key improvements:

- Optional outputs: `y_fp4` and `block_scale` can now be either provided for in-place update or omitted for automatic allocation and return.
- Optional `global_scale` tensor (`torch.Tensor | None`, shape `[1]`, dtype float32) for NVFP4 quantization, enabling proper dynamic range scaling when global_scale is pre-computed. Should not be provided for MXFP4.

File Changes:

- `rmsnorm_fp4quant.py` / `add_rmsnorm_fp4quant.py`: Added a `global_scale: torch.Tensor | None = None` parameter; the kernel now reads the global scale from device memory and incorporates it into the block scale computation.
- `bench_cute_dsl_rmsnorm_fp4quant.py` / `bench_cute_dsl_add_rmsnorm_fp4quant.py`: Updated the unfused baseline to measure the time for (add +) rmsnorm + fp4 quant end to end, instead of measuring each step separately.
- `test_rmsnorm_fp4_quant_cute_dsl.py` / `test_add_rmsnorm_fp4_quant_cute_dsl.py`: Added auto-allocation tests, global-scale verification tests, and fused-vs-separate comparison tests.

API Changes:
B200 (SM100) Benchmarks
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Benchmark Improvements
Testing