Skip to content

fix: Add global scale support and optional output allocation for RMSNorm+FP4Quant fusion kernels#2260

Merged
yzh119 merged 4 commits intoflashinfer-ai:mainfrom
bkryu:rmsnorm_fusion_global_sf
Jan 1, 2026
Merged

fix: Add global scale support and optional output allocation for RMSNorm+FP4Quant fusion kernels#2260
yzh119 merged 4 commits intoflashinfer-ai:mainfrom
bkryu:rmsnorm_fusion_global_sf

Conversation

@bkryu
Copy link
Copy Markdown
Collaborator

@bkryu bkryu commented Dec 23, 2025

📌 Description

This PR enhances the rmsnorm_fp4quant and add_rmsnorm_fp4quant CuTe-DSL kernels with two key improvements:

  • Optional output allocation: y_fp4 and block_scale outputs can now be either provided for in-place update or omitted for automatic allocation and return
  • Global scale support: Both fusion patterns now accept an optional global_scale tensor (torch.Tensor | None, shape [1], dtype float32) for NVFP4 quantization, enabling proper dynamic range scaling when global_scale is pre-computed. Should not be provided for mxfp4

File Changes:

  • rmsnorm_fp4quant.py / add_rmsnorm_fp4quant.py: Added global_scale: torch.Tensor | None = None parameter; kernel now reads global scale from device memory and incorporates it into block scale computation
  • bench_cute_dsl_rmsnorm_fp4quant.py / bench_cute_dsl_add_rmsnorm_fp4quant.py: Updated unfused baseline to measure time for (add +) rmsnorm + fp4 quant, instead of measuring separately.
  • test_rmsnorm_fp4_quant_cute_dsl.py / test_add_rmsnorm_fp4_quant_cute_dsl.py: Added auto-allocation tests, global scale verification tests, and fused-vs-separate comparison tests.

API Changes:

# Before: outputs required
rmsnorm_fp4quant(x, weight, y_fp4, block_scale, ...)

# After: outputs optional, global_scale supported
y_fp4, block_scale = rmsnorm_fp4quant(x, weight, global_scale=gs, ...)  # auto-allocate
rmsnorm_fp4quant(x, weight, y_fp4, block_scale, global_scale=gs, ...)   # in-place
B200 (SM100) Benchmarks
$ python3 bench_cute_dsl_rmsnorm_fp4quant.py 
================================================================================
Fused RMSNorm + FP4 Quantization Benchmark
================================================================================
GPU Compute Capability: SM100

Running sanity check...
  OK: (128, 256) - FP4 match 99.8%
  OK: (512, 1024) - FP4 match 99.8%
  OK: (1024, 2048) - FP4 match 99.8%
✓ Confirmed: CuTe-DSL output is equivalent to RMSNorm + fp4_quantize


Batch    Hidden   Fused (µs)   BW (GB/s)  Unfused (µs)   Speedup   
-------------------------------------------------------------------
1000     1536     4.4          898.5      6.8            1.54x     
1000     2048     5.2          1019.4     7.4            1.43x     
1000     4096     6.7          1563.1     12.1           1.80x     
1000     8192     9.2          2291.5     20.2           2.20x     
1000     16384    22.1         1897.4     31.5           1.42x     
1000     32768    31.6         2663.3     52.0           1.65x     
1024     1536     4.4          920.1      6.8            1.55x     
1024     2048     5.1          1050.4     7.4            1.44x     
1024     4096     6.8          1593.1     12.2           1.80x     
1024     8192     9.2          2342.4     20.3           2.21x     
1024     16384    22.9         1880.4     31.8           1.39x     
1024     32768    31.9         2697.1     51.9           1.63x     
2048     1536     5.5          1465.1     9.9            1.80x     
2048     2048     6.5          1663.4     11.6           1.80x     
2048     4096     9.1          2357.9     20.1           2.20x     
2048     8192     16.8         2562.4     34.6           2.06x     
2048     16384    36.5         2357.9     57.3           1.57x     
2048     32768    53.5         3217.2     94.1           1.76x     
3000     1536     6.5          1818.2     12.7           1.96x     
3000     2048     7.7          2033.6     15.2           1.97x     
3000     4096     12.3         2563.2     26.9           2.19x     
3000     8192     22.4         2816.3     50.4           2.25x     
3000     16384    49.0         2569.9     83.1           1.70x     
3000     32768    73.2         3443.0     130.5          1.78x     
4096     1536     7.5          2153.4     15.4           2.05x     
4096     2048     8.8          2434.3     19.3           2.19x     
4096     4096     16.5         2606.7     35.4           2.14x     
4096     8192     29.2         2943.6     66.8           2.29x     
4096     16384    61.3         2803.8     109.1          1.78x     
4096     32768    95.8         3591.7     173.8          1.81x     
5000     1536     8.5          2312.4     18.2           2.14x     
5000     2048     10.4         2531.3     22.9           2.21x     
5000     4096     18.7         2803.9     42.3           2.26x     
5000     8192     35.2         2982.3     80.0           2.27x     
5000     16384    72.7         2889.0     130.0          1.79x     
5000     32768    114.1        3680.8     206.1          1.81x     
8192     1536     11.6         2776.2     27.1           2.33x     
8192     2048     15.6         2747.7     34.3           2.19x     
8192     4096     28.6         3002.4     67.6           2.36x     
8192     8192     52.4         3279.1     127.2          2.42x     
8192     16384    113.9        3021.1     209.4          1.84x     
8192     32768    178.5        3854.4     332.1          1.86x     
10000    1536     14.1         2783.0     31.6           2.23x     
10000    2048     17.8         2944.7     40.3           2.26x     
10000    4096     34.5         3038.7     81.3           2.35x     
10000    8192     62.1         3380.8     153.1          2.46x     
10000    16384    135.2        3106.7     252.2          1.87x     
10000    32768    214.7        3911.2     401.1          1.87x     
15000    1536     19.4         3044.7     45.8           2.36x     
15000    2048     25.2         3126.0     59.7           2.37x     
15000    4096     47.4         3322.2     118.1          2.49x     
15000    8192     89.0         3539.8     224.8          2.53x     
15000    16384    192.3        3274.4     373.5          1.94x     
15000    32768    315.1        3997.2     592.1          1.88x     
16384    1536     20.9         3086.3     50.2           2.40x     
16384    2048     27.2         3165.0     64.8           2.39x     
16384    4096     51.0         3371.5     128.2          2.51x     
16384    8192     96.3         3570.9     245.7          2.55x     
16384    16384    210.2        3272.7     407.1          1.94x     
16384    32768    342.7        4014.3     646.6          1.89x     
25000    1536     30.4         3231.8     75.1           2.47x     
25000    2048     38.7         3392.7     96.8           2.50x     
25000    4096     73.0         3596.6     191.8          2.63x     
25000    8192     142.4        3686.3     369.4          2.59x     
25000    16384    310.0        3386.3     614.6          1.98x     
25000    32768    515.6        4071.7     976.8          1.89x     
32768    1536     38.2         3378.5     96.8           2.53x     
32768    2048     48.2         3568.4     124.3          2.58x     
32768    4096     92.8         3705.0     249.0          2.68x     
32768    8192     184.0        3739.5     482.0          2.62x     
32768    16384    401.8        3424.1     799.3          1.99x     
32768    32768    672.9        4088.8     1312.0         1.95x     
60000    1536     64.1         3682.7     171.8          2.68x     
60000    2048     81.5         3863.4     222.0          2.72x     
60000    4096     162.3        3880.2     449.3          2.77x     
60000    8192     329.5        3822.1     873.7          2.65x     
60000    16384    719.2        3502.5     1458.1         2.03x     
60000    32768    1265.2       3982.2     2440.1         1.93x     
65536    1536     69.3         3723.3     187.5          2.71x     
65536    2048     88.3         3895.6     242.6          2.75x     
65536    4096     176.5        3896.3     489.2          2.77x     
65536    8192     359.2        3830.4     953.7          2.66x     
65536    16384    783.9        3510.1     1590.3         2.03x     
65536    32768    1341.8       4101.3     2705.2         2.02x     

================================================================================
Geomean speedup vs Unfused (rmsnorm + fp4_quantize): 2.10x
================================================================================
Benchmark Complete
================================================================================

$ python3 bench_cute_dsl_add_rmsnorm_fp4quant.py
================================================================================
Fused Add + RMSNorm + FP4 Quantization Benchmark
================================================================================
GPU Compute Capability: SM100

Running sanity check...
  OK: (128, 256) - FP4 match 99.9%
  OK: (512, 1024) - FP4 match 99.9%
  OK: (1024, 2048) - FP4 match 99.9%
✓ Confirmed: CuTe-DSL output is equivalent to torch.add + RMSNorm + fp4_quantize


Batch    Hidden   Fused (µs)   BW (GB/s)  Unfused (µs)   Speedup   
-------------------------------------------------------------------
1000     1536     5.0          1413.5     9.7            1.96x     
1000     2048     5.5          1708.4     10.7           1.95x     
1000     4096     8.9          2094.1     16.4           1.84x     
1000     8192     13.1         2864.0     27.5           2.11x     
1000     16384    33.5         2232.1     44.4           1.33x     
1000     32768    66.7         2243.9     83.8           1.26x     
1024     1536     5.0          1438.2     9.8            1.96x     
1024     2048     5.5          1729.1     10.8           1.95x     
1024     4096     9.0          2121.5     16.6           1.84x     
1024     8192     13.2         2890.1     27.8           2.10x     
1024     16384    34.5         2220.9     45.0           1.31x     
1024     32768    67.4         2272.1     85.5           1.27x     
2048     1536     7.1          2020.8     13.9           1.96x     
2048     2048     8.7          2211.3     16.6           1.92x     
2048     4096     13.1         2928.4     27.4           2.09x     
2048     8192     22.2         3447.5     49.0           2.20x     
2048     16384    61.7         2481.9     89.9           1.46x     
2048     32768    121.2        2525.8     155.4          1.28x     
3000     1536     9.9          2130.1     18.0           1.82x     
3000     2048     10.6         2638.8     21.4           2.02x     
3000     4096     17.1         3275.2     38.1           2.23x     
3000     8192     30.5         3675.4     73.7           2.42x     
3000     16384    86.7         2587.8     128.4          1.48x     
3000     32768    170.0        2639.4     218.1          1.28x     
4096     1536     11.2         2555.9     22.0           1.96x     
4096     2048     12.5         3067.1     27.2           2.18x     
4096     4096     22.1         3462.1     49.6           2.24x     
4096     8192     39.1         3915.4     98.9           2.53x     
4096     16384    115.7        2646.4     170.0          1.47x     
4096     32768    224.7        2725.9     291.4          1.30x     
5000     1536     13.5         2598.1     25.8           1.91x     
5000     2048     14.6         3209.1     32.2           2.21x     
5000     4096     25.9         3609.7     60.3           2.33x     
5000     8192     45.9         4068.6     118.5          2.58x     
5000     16384    137.7        2714.0     202.8          1.47x     
5000     32768    269.2        2777.2     349.1          1.30x     
8192     1536     19.7         2917.3     38.9           1.97x     
8192     2048     20.8         3680.3     49.4           2.38x     
8192     4096     38.8         3941.0     100.4          2.58x     
8192     8192     70.5         4343.5     188.2          2.67x     
8192     16384    220.1        2782.6     326.6          1.48x     
8192     32768    427.1        2867.7     563.7          1.32x     
10000    1536     23.3         3004.2     45.4           1.95x     
10000    2048     24.5         3819.6     59.4           2.43x     
10000    4096     45.4         4112.9     120.5          2.65x     
10000    8192     84.3         4432.8     226.4          2.69x     
10000    16384    267.8        2791.0     393.8          1.47x     
10000    32768    517.6        2888.4     683.8          1.32x     
15000    1536     33.2         3167.9     67.5           2.03x     
15000    2048     34.3         4085.9     90.2           2.63x     
15000    4096     64.7         4334.6     174.9          2.70x     
15000    8192     122.2        4587.7     333.1          2.73x     
15000    16384    397.2        2823.4     582.8          1.47x     
15000    32768    766.5        2925.8     1014.2         1.32x     
16384    1536     36.0         3192.3     74.9           2.08x     
16384    2048     36.9         4145.8     98.1           2.66x     
16384    4096     69.9         4379.2     189.6          2.71x     
16384    8192     132.8        4609.6     363.1          2.73x     
16384    16384    433.0        2828.8     635.6          1.47x     
16384    32768    837.1        2926.2     1113.5         1.33x     
25000    1536     51.3         3417.7     112.4          2.19x     
25000    2048     52.0         4496.5     145.2          2.79x     
25000    4096     102.8        4546.2     283.1          2.75x     
25000    8192     197.7        4725.8     547.1          2.77x     
25000    16384    653.6        2859.5     962.6          1.47x     
25000    32768    1266.7       2950.7     1726.5         1.36x     
32768    1536     64.7         3547.3     144.4          2.23x     
32768    2048     66.0         4639.2     186.7          2.83x     
32768    4096     132.4        4625.2     367.1          2.77x     
32768    8192     256.2        4779.7     713.6          2.78x     
32768    16384    856.9        2858.6     1259.6         1.47x     
32768    32768    1652.6       2964.4     2267.0         1.37x     
60000    1536     112.7        3729.8     255.0          2.26x     
60000    2048     115.0        4876.2     331.3          2.88x     
60000    4096     235.2        4767.4     662.0          2.81x     
60000    8192     462.3        4851.2     1294.6         2.80x     
60000    16384    1560.6       2873.9     2311.2         1.48x     
60000    32768    3008.9       2981.2     4225.7         1.40x     
65536    1536     122.4        3751.8     277.6          2.27x     
65536    2048     124.9        4901.8     361.0          2.89x     
65536    4096     256.2        4780.9     721.2          2.82x     
65536    8192     503.5        4864.7     1412.8         2.81x     
65536    16384    1703.0       2876.7     2508.2         1.47x     
65536    32768    3288.8       2979.2     4617.6         1.40x     

================================================================================
Geomean speedup vs Unfused (add + rmsnorm + fp4_quantize): 1.96x
================================================================================
Benchmark Complete
================================================================================

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added optional global scaling for FP4 quantization; quantization APIs now return quantized output plus block scales and support auto-allocation.
  • Benchmark Improvements

    • Benchmarks now propagate global_scale, report fused vs unfused timings, and show a single speedup metric versus the unfused path with simplified output formatting.
  • Testing

    • Expanded tests to cover global-scale paths, auto-allocation, swizzled layouts, large sizes, and introduced two-tier tolerance assertions.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Dec 23, 2025

📝 Walkthrough

Walkthrough

Adds an optional global_scale parameter across fused CuTe-DSL Add+RMSNorm+FP4 and RMSNorm+FP4 kernels, threads it through host→device pointer bindings and kernels, updates public APIs to accept/return global_scale and (y_fp4, block_scale), refactors benchmarks to report FUSED vs UNFUSED timings and speedup, and extends tests with global_scale-aware checks and two-tier tolerances.

Changes

Cohort / File(s) Summary
Benchmark Refactoring
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py, benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
Thread global_scale through benchmark entry points; add bench_unfused; unify unfused timing into a single metric; run_benchmark uses a fixed global_scale; output now shows FUSED time, UNFUSED time, and speedup against unfused.
Kernel: Add+RMSNorm+FP4Quant
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
Add global_scale_ptr to host/kernel signatures and pointer API; host builds 1-element global_scale tensor; device reads global_scale and applies it into per-block scale/inv_scale math for FP4 paths; public API accepts optional global_scale and returns (y_fp4, block_scale).
Kernel: RMSNorm+FP4Quant
flashinfer/cute_dsl/rmsnorm_fp4quant.py
Add global_scale_ptr to host/kernel call paths and tensor_api/get_cute_pointers; kernel incorporates global_scale into block-scale and inv_scale computations; public API accepts optional global_scale and returns (y_fp4, block_scale).
Tests & Utilities: Add+RMSNorm+FP4Quant
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
Extend dequantize_fp4_output to accept global_scale; add compute_global_scale() and assert_close_with_tiered_tolerance(); add tests covering auto-allocation, fused vs unfused paths, swizzled layouts, large sizes, and global_scale consistency.
Tests & Utilities: RMSNorm+FP4Quant
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
Same test utility additions and updates: dequantize_fp4_output(global_scale), compute_global_scale(), two-tier tolerance checks, and expanded coverage for global_scale across NVFP4/MXFP4 formats, swizzled layouts, and auto-allocation.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Bench as Benchmark Runner
  participant Host as Python Host API
  participant Device as CUDA Kernel
  participant Mem as Device Memory (y_fp4, block_scale)

  Note right of Bench: benchmark uses fixed global_scale
  Bench->>Host: call add_rmsnorm_fp4quant(x,r,w,..., global_scale)
  Host->>Device: bind pointers (x,r,w,y,s,global_scale) & launch kernel
  Device->>Mem: read global_scale, compute per-block scale & inv_scale
  Device->>Mem: write y_fp4 and block_scale
  Device-->>Host: kernel returns (y_fp4_ptr, block_scale_ptr)
  Host-->>Bench: return (y_fp4, block_scale) -> record FUSED time
  Note left of Bench: UNFUSED path runs Add -> RMSNorm -> FP4Quant with same global_scale
  Bench->>Bench: compute speedup = unfused_us / fused_us
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • kaixih
  • aleozlx
  • cyx-6
  • kahyunnam
  • jiahanc

"I hopped through kernels, scales in tow,
Fused and unfused both ready to show,
Blocks find their range, bytes fall in line,
Tests give me carrots, timing smells fine.
— a happy benchmarking rabbit 🐇"

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly identifies the two main enhancements (global scale support and optional output allocation) for the specific RMSNorm+FP4Quant fusion kernels.
Description check ✅ Passed The description covers the main changes, affected files, API modifications, and includes benchmark results. However, the test checklist shows unchecked items indicating tests may not all be passing.
Docstring Coverage ✅ Passed Docstring coverage is 96.97% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the rmsnorm_fp4quant and add_rmsnorm_fp4quant CuTe-DSL kernels by introducing greater flexibility in output handling and improving quantization accuracy. The new optional output allocation streamlines API usage, while the addition of global scale support for NVFP4 quantization allows for more precise dynamic range scaling, which is critical for maintaining model performance with reduced precision. These changes aim to make the kernels more robust, user-friendly, and performant for advanced quantization techniques.

Highlights

  • Optional Output Allocation: The rmsnorm_fp4quant and add_rmsnorm_fp4quant CuTe-DSL kernels now support optional output allocation. Users can either provide pre-allocated y_fp4 and block_scale tensors for in-place updates or omit them to have the kernels automatically allocate and return the outputs.
  • Global Scale Support: Both rmsnorm_fp4quant and add_rmsnorm_fp4quant kernels now accept an optional global_scale tensor. This feature is specifically designed for NVFP4 quantization, allowing for proper dynamic range scaling when a global scale is pre-computed, which helps improve quantization quality.
  • Benchmarking Improvements: The benchmarking scripts have been refactored to simplify comparisons, focusing on the overall performance of the fused operation against a combined unfused baseline, and incorporating the new global_scale parameter.
  • Comprehensive Testing: Extensive new test cases have been added to validate the optional output allocation, the correctness of global scale application for NVFP4, and to ensure consistency between fused and separate operations for both NVFP4 and MXFP4 quantization formats.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Dec 23, 2025

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !213 has been created, and the CI pipeline #40672783 is currently running. I'll report back once the pipeline job completes.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (5)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (1)

733-752: Prefix unused variables with underscore.

The unpacked variables y_fp4_gs and y_fp4_no_gs are intentionally unused since this test focuses on comparing block scale ratios. Prefixing with underscore clarifies intent and silences the linter.

🔎 Proposed fix
-        y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
+        _y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
             x,
             weight,
             global_scale=global_scale,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )

         # Run without global_scale (global_scale=1.0)
         global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda")

-        y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
+        _y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
             x,
             weight,
             global_scale=global_scale_one,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (1)

690-711: Prefix unused variables with underscore.

Same as in the RMSNorm test file, the unpacked y_fp4_gs and y_fp4_no_gs variables are intentionally unused.

🔎 Proposed fix
-        y_fp4_gs, block_scale_gs = add_rmsnorm_fp4quant(
+        _y_fp4_gs, block_scale_gs = add_rmsnorm_fp4quant(
             x,
             r,
             weight,
             global_scale=global_scale,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )

         # Run without global_scale (global_scale=1.0)
         global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda")

-        y_fp4_no_gs, block_scale_no_gs = add_rmsnorm_fp4quant(
+        _y_fp4_no_gs, block_scale_no_gs = add_rmsnorm_fp4quant(
             x,
             r,
             weight,
             global_scale=global_scale_one,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (3)

2286-2288: Add validation for global_scale tensor properties.

When global_scale is provided by the caller, there's no validation of its shape, dtype, or device. Invalid inputs could cause runtime errors or incorrect results.

Proposed validation

Add validation after line 2388:

     sm_version = get_sm_version(input.device)
 
+    # Validate global_scale if provided
+    if global_scale is not None:
+        assert global_scale.shape == (1,) or global_scale.numel() == 1, (
+            f"global_scale must have shape (1,), got {global_scale.shape}"
+        )
+        assert global_scale.dtype == torch.float32, (
+            f"global_scale must have dtype torch.float32, got {global_scale.dtype}"
+        )
+        assert global_scale.device == input.device, (
+            f"global_scale device {global_scale.device} must match input device {input.device}"
+        )
+        # Flatten to shape (1,) if needed
+        global_scale = global_scale.reshape(1)
+
     # Allocate output tensors if not provided

Also applies to: 2442-2443


1755-1772: Consider warning when global_scale is provided with UE8M0 format.

For UE8M0 (MXFP4, block_size=32), global_scale is silently ignored (lines 1755-1759, 2079-2083). Users might mistakenly provide global_scale expecting it to have an effect. Consider adding a validation or warning to make this explicit.

You could add a check in the public API after line 2386:

if global_scale is not None and actual_scale_format == "ue8m0":
    import warnings
    warnings.warn(
        "global_scale is only supported for E4M3 format and will be ignored for UE8M0 (MXFP4)",
        UserWarning
    )

Also applies to: 2079-2096


1377-1381: Clarify comment about "canceling global_scale".

The phrase "to cancel global_scale" at lines 1377 and 1559 may be confusing. The purpose is to ensure the quantized intermediate values are computed using standard quantization (without global_scale in the quantization step), while global_scale is retained in the stored block scale. Consider rewording for clarity.

Suggested comment clarification
-                            # inv_scale = global_scale / scale_float to cancel global_scale
+                            # inv_scale excludes global_scale from quantization computation
+                            # so q = y / (max_abs / FP4_MAX), while scale_fp8 = global_scale * max_abs / FP4_MAX
                             inv_scale = (
                                 fp8_e4m3_to_f32_and_rcp(scale_fp8_u32)
                                 * global_scale_val
                             )

Also applies to: 1559-1563

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25de38e and 1ba2ead.

📒 Files selected for processing (6)
  • benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
  • benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (5)
  • kernel (1091-1728)
  • fmin_f32 (208-220)
  • cvt_f32_to_e4m3 (462-482)
  • fp8_e4m3_to_f32_and_rcp (486-518)
  • get_cute_pointers (1751-1792)
flashinfer/cute_dsl/utils.py (1)
  • make_ptr (175-223)
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)
  • compute_global_scale (83-105)
  • llama_rms_norm (31-39)
  • unswizzle_sf (851-884)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
  • rmsnorm_fp4quant (1836-2013)
🪛 Ruff (0.14.10)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py

74-74: Avoid specifying long messages outside the exception class

(TRY003)


143-146: Avoid specifying long messages outside the exception class

(TRY003)


690-690: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


703-703: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py

343-343: Do not catch blind exception: Exception

(BLE001)

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py

337-337: Do not catch blind exception: Exception

(BLE001)


350-350: Do not catch blind exception: Exception

(BLE001)

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py

78-78: Avoid specifying long messages outside the exception class

(TRY003)


149-152: Avoid specifying long messages outside the exception class

(TRY003)


733-733: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


745-745: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (4)

1836-1846: LGTM on the API signature update.

The function signature now correctly supports optional output allocation with y_fp4 and block_scale as optional parameters, and adds global_scale support. The return type is appropriately updated to Tuple[torch.Tensor, torch.Tensor].


1940-1993: Auto-allocation logic is well-implemented.

The logic correctly handles:

  • 2D vs 3D input shapes
  • Scale dtype selection based on format (UE8M0 → uint8, E4M3 → float8_e4m3fn)
  • Swizzled layout size calculation with 128x4 tile pattern
  • Default global_scale=1.0 when not provided

1252-1255: Verify global_scale read location for performance.

The global_scale_val is read from device memory inside the kernel loop. While this is correct for CUDA graph compatibility (as noted in the comment), reading it once per thread block rather than per-thread would be more efficient. However, the compiler likely optimizes this to a single load per warp.


1918-1925: LGTM on input reshaping.

The 2D/3D input handling is correct. The input_2d variable is properly used for kernel execution, and .contiguous() is called at the point of use (line 2004).

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3)

87-132: LGTM on global_scale integration in benchmarks.

The global_scale parameter is properly threaded through to the fused kernel call. The benchmark structure correctly measures the fused kernel performance with the new parameter.


135-176: Good refactoring of unfused benchmark.

Consolidating the unfused operations (rmsnorm + fp4_quantize) into a single timed function provides a more accurate comparison against the fused kernel. The global_scale is correctly passed to fp4_quantize for NVFP4.


337-354: Blind exception handling is acceptable here.

While static analysis flags except Exception, this pattern is appropriate in benchmark code to ensure the suite continues running even if individual configurations fail. The error message includes sufficient context (batch_size, hidden_size, and exception details).

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (3)

90-111: LGTM on compute_global_scale helper.

The formula global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / max_abs(rmsnorm_output) correctly computes the optimal global scale to maximize dynamic range utilization for NVFP4 quantization.


114-152: Well-designed tiered tolerance check.

The two-tier tolerance approach appropriately handles quantization noise:

  • 99% of elements must match within tight tolerance (rtol=0.1, atol=0.1)
  • 100% of elements must match within loose tolerance (rtol=0.5, atol=2.0)

This is more robust than a single tolerance threshold for FP4 quantized outputs.


1075-1121: Excellent test coverage for auto-allocation.

The TestAutoAllocation class comprehensively covers:

  • 2D and 3D input shapes
  • NVFP4 (with global_scale) and MXFP4 formats
  • Swizzled layout auto-allocation
  • Equivalence between auto-allocated and pre-allocated paths

This ensures the new optional output allocation feature works correctly across all configurations.

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)

76-122: LGTM on global_scale integration.

The global_scale parameter is correctly propagated to the fused add_rmsnorm_fp4quant kernel. The benchmark structure mirrors the RMSNorm-only benchmark file, maintaining consistency.


125-167: Good unfused benchmark implementation.

The unfused benchmark correctly:

  1. Pre-allocates intermediate tensors (h, y_normed) outside the timed region
  2. Times the combined add + rmsnorm + fp4_quantize workflow
  3. Passes global_scale to fp4_quantize for NVFP4 consistency

343-348: Blind exception handling is acceptable in benchmark code.

Similar to the other benchmark file, catching broad exceptions here ensures the benchmark suite continues running through all configurations. The error is logged with context.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)

83-105: LGTM on compute_global_scale for add+rmsnorm.

The function correctly computes global_scale based on the rmsnorm(x + residual, weight) output, which is the appropriate reference for the add+rmsnorm fusion.


466-575: Excellent test coverage for fused vs separate comparison.

The TestFusedVsSeparateFP4Quantize class thoroughly validates:

  1. FP4 packed output byte-level matching (>95%)
  2. Block scale factor matching (>95%)
  3. Dequantized value closeness with tiered tolerance
  4. Both NVFP4 and MXFP4 formats

This ensures the fused kernel produces results consistent with the separate implementation.


1034-1078: Good test coverage for auto-allocation.

The TestAutoAllocation class mirrors the RMSNorm test file structure, providing comprehensive coverage for the add+rmsnorm fusion kernel's auto-allocation feature.

Comment on lines +1109 to +1114
"""Device kernel with cluster sync and Half2 SIMD.

mGlobalScale contains the global scale value. The kernel reads it and
computes 1/global_scale, which is multiplied with rstd to apply:
y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Docstring inaccurately describes global_scale behavior.

The docstring states that the kernel "computes 1/global_scale, which is multiplied with rstd" and produces "y = rmsnorm(h, w) / global_scale". However, the implementation does not multiply rstd by 1/global_scale. Instead:

  1. RMSNorm is computed normally: y = h * rstd * w (lines 1335-1350)
  2. global_scale is incorporated into the stored block scale: scale_float = global_scale_val * max_abs / FP4_MAX (line 1371)
  3. Quantization uses inv_scale = FP4_MAX / max_abs which cancels out global_scale (lines 1378-1381)
  4. Net effect: dequantized output = rmsnorm(h, w) * global_scale (multiplies, not divides)

The docstring should clarify that global_scale multiplies the dequantized output by being incorporated into the stored block scale, rather than claiming it divides the RMSNorm output.

Proposed docstring correction
-        """Device kernel with cluster sync and Half2 SIMD.
-
-        mGlobalScale contains the global scale value. The kernel reads it and
-        computes 1/global_scale, which is multiplied with rstd to apply:
-        y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
-        """
+        """Device kernel with cluster sync and Half2 SIMD.
+
+        mGlobalScale contains the global scale value that is incorporated into
+        the block scale. RMSNorm is computed normally (y = h * rstd * w), but
+        the stored block scale is multiplied by global_scale. This results in
+        dequantized outputs being scaled: dequant_output = rmsnorm(h, w) * global_scale.
+        """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"""Device kernel with cluster sync and Half2 SIMD.
mGlobalScale contains the global scale value. The kernel reads it and
computes 1/global_scale, which is multiplied with rstd to apply:
y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
"""
"""Device kernel with cluster sync and Half2 SIMD.
mGlobalScale contains the global scale value that is incorporated into
the block scale. RMSNorm is computed normally (y = h * rstd * w), but
the stored block scale is multiplied by global_scale. This results in
dequantized outputs being scaled: dequant_output = rmsnorm(h, w) * global_scale.
"""
🤖 Prompt for AI Agents
In flashinfer/cute_dsl/add_rmsnorm_fp4quant.py around lines 1109-1114, the
docstring incorrectly states the kernel computes 1/global_scale and divides the
RMSNorm output; instead, the implementation incorporates global_scale into the
stored block scale so the dequantized output is multiplied by global_scale.
Update the docstring to state that RMSNorm is computed normally (y = h * rstd *
w), that global_scale is factored into the stored block scale (scale_float =
global_scale_val * max_abs / FP4_MAX) and thus the dequantized result is
multiplied by global_scale, and remove the incorrect “1/global_scale” language;
reference the relevant implementation lines (≈1335-1350, 1371, 1378-1381) for
clarity.

Comment on lines +2328 to 2335
global_scale : torch.Tensor, optional
Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
If provided, the RMSNorm output is divided by this value before quantization:
``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
This is used for NVFP4 format where a pre-computed global scale lifts
per-block scales into optimal dynamic range.
If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
eps : float
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Docstring incorrectly describes global_scale effect on quantization.

The docstring claims that "the RMSNorm output is divided by this value before quantization: y = rmsnorm(h, w) / global_scale". This is inaccurate. The implementation:

  1. Computes RMSNorm normally without applying global_scale to intermediate values
  2. Incorporates global_scale into the stored block scale
  3. Results in dequantized output = rmsnorm(h, w) * global_scale (multiply, not divide)

The parameter description should clarify that global_scale adjusts the magnitude of dequantized outputs by being baked into the block scale, and that larger values produce larger outputs (not smaller).

Proposed docstring correction
     global_scale : torch.Tensor, optional
         Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
-        If provided, the RMSNorm output is divided by this value before quantization:
-        ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
-        This is used for NVFP4 format where a pre-computed global scale lifts
-        per-block scales into optimal dynamic range.
+        If provided, this value is incorporated into the per-block scales for E4M3 format.
+        The effect is to scale the dequantized output: ``dequant = rmsnorm(h, w) * global_scale``
+        where ``h = input + residual``. This adjusts the magnitude of outputs without affecting
+        quantization granularity. Only used for E4M3 format; ignored for UE8M0 (MXFP4).
         If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
🤖 Prompt for AI Agents
In flashinfer/cute_dsl/add_rmsnorm_fp4quant.py around lines 2328 to 2335, the
docstring incorrectly states that the RMSNorm output is divided by global_scale
before quantization; instead, global_scale is incorporated into the stored block
scale so the dequantized output is multiplied by global_scale (dequantized =
rmsnorm(h, w) * global_scale). Update the parameter description to explain that
providing global_scale bakes that factor into the block scale and increases the
magnitude of dequantized outputs (larger global_scale -> larger outputs), and
clarify that None means no global scaling (equivalent to global_scale=1.0).

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a high-quality contribution that enhances the rmsnorm_fp4quant and add_rmsnorm_fp4quant kernels with optional output allocation and support for a global scale factor. The implementation is robust, with the new features correctly integrated into the CuTe-DSL kernels, Python API, benchmarks, and tests.

The API changes make the functions more flexible by allowing automatic allocation of output tensors. The global scale support for NVFP4 quantization is correctly implemented by incorporating the scale into the block scale computation within the kernel, which is crucial for dynamic range management.

The benchmarks have been significantly improved by refactoring the unfused baseline to measure the entire pipeline, providing a more realistic performance comparison. The tests are exceptionally thorough, with new test classes validating auto-allocation, global scale correctness, and the equivalence between fused and separate execution paths. The introduction of a tiered tolerance assertion function is an excellent addition for robustly testing low-precision quantized outputs.

Overall, the changes are well-executed, well-tested, and improve both the functionality and usability of the fusion kernels. I have no specific comments on the code changes.

@bkryu bkryu self-assigned this Dec 23, 2025
@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[SUCCESS] Pipeline #40672783: 12/20 passed

@bkryu bkryu force-pushed the rmsnorm_fusion_global_sf branch from 1ba2ead to c660708 Compare December 30, 2025 18:43
@bkryu bkryu requested a review from Anerudhan as a code owner December 30, 2025 18:43
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (3)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (3)

1096-1114: Kernel docstring still describes incorrect 1/global_scale behavior.

Lines 1111–1114 claim the kernel “computes 1/global_scale, which is multiplied with rstd” and that y = rmsnorm(h, w) / global_scale. The implementation doesn’t modify rstd with 1/global_scale; it uses mGlobalScale only in E4M3 scale/inv_scale math, leaving RMSNorm unchanged and expecting dequantization to divide by global_scale.

Align this docstring with the implemented behavior: RMSNorm is computed normally, global_scale is folded into block scales for NVFP4/E4M3, and consumers should divide by global_scale when dequantizing.


2281-2335: Clarify public API global_scale semantics (baked into block scales, undone at dequant).

The global_scale parameter docstring still says “RMSNorm output is divided by this value before quantization: y = rmsnorm(h, w) / global_scale”, which doesn’t reflect the implementation:

  • The fused kernel computes RMSNorm normally.
  • For NVFP4/E4M3, global_scale is folded into the stored block scales and inv_scale.
  • Tests dequantize by multiplying FP4 bytes with the block scales and then dividing by global_scale to recover the RMSNorm output.

To avoid confusion, please update this section to something along the lines of:

  • global_scale is only used for E4M3 / NVFP4; it is ignored for UE8M0 (MXFP4).
  • It is incorporated into the per-block scales (scale ≈ global_scale * max_abs / FP4_MAX).
  • Downstream dequantization should divide by global_scale to reverse this factor; with global_scale=None (or 1.0), behavior matches prior semantics.

1016-1027: Update call docstring to match actual global_scale usage.

The kernel no longer multiplies rstd by 1/global_scale or computes y = rmsnorm(h) / global_scale; instead, global_scale is read into mGlobalScale and used only when forming E4M3 block scales and inv_scale. Conceptually, global_scale is baked into per-block scales, and downstream dequantization is expected to divide by global_scale to recover the RMSNorm output.

Please reword this docstring to describe:

  • RMSNorm computed normally (y = h * rstd * w),
  • global_scale incorporated into E4M3 block scales (scale ≈ global_scale * max_abs / FP4_MAX),
  • Dequantizers should divide by global_scale to undo this factor.
🧹 Nitpick comments (9)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)

1939-1992: Consider adding shape validation for user-provided output tensors.

The output allocation logic correctly handles all cases (2D/3D, swizzled/non-swizzled). However, when users provide pre-allocated y_fp4 or block_scale tensors, there's no validation to ensure they have the correct shapes and dtypes.

For example:

  • If y_fp4 has shape (batch_size, hidden_size) instead of (batch_size, hidden_size // 2), the kernel will write to incorrect memory locations
  • If block_scale is torch.uint8 but scale_format="e4m3", type mismatch will occur
🔎 Proposed validation logic

Add validation before kernel launch (around line 1939):

+    # Validate user-provided output tensors
+    if y_fp4 is not None:
+        expected_shape = (batch_size, hidden_size // 2) if not is_3d else (B, S, hidden_size // 2)
+        if y_fp4.shape != expected_shape:
+            raise ValueError(f"y_fp4 shape mismatch: expected {expected_shape}, got {y_fp4.shape}")
+        if y_fp4.dtype != torch.uint8:
+            raise ValueError(f"y_fp4 dtype must be torch.uint8, got {y_fp4.dtype}")
+
+    if block_scale is not None:
+        scale_dtype = torch.uint8 if actual_scale_format == "ue8m0" else torch.float8_e4m3fn
+        if block_scale.dtype != scale_dtype:
+            raise ValueError(f"block_scale dtype must be {scale_dtype} for {actual_scale_format}, got {block_scale.dtype}")
+        # Add shape validation based on layout mode
+
     # Allocate output tensors if not provided
     if y_fp4 is None:

1252-1254: Clarify global_scale behavior for MXFP4 (UE8M0) in docstring.

The implementation correctly incorporates global_scale for NVFP4 format (E4M3 scales) and intentionally ignores it for MXFP4 format (UE8M0 scales). However, the docstring is ambiguous about this limitation.

The global_scale parameter documentation states: "This is used for NVFP4 format where a pre-computed global scale lifts per-block scales into optimal dynamic range." This implies the limitation but doesn't explicitly state that global_scale is ignored for MXFP4 when block_size=32 or scale_format="ue8m0".

Update the global_scale parameter documentation to explicitly note: "Note: This parameter is only used for NVFP4 format (block_size=16, E4M3 scales). For MXFP4 format (block_size=32, UE8M0 scales), global_scale is ignored."

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2368-2465: Auto-allocation and 2D/3D/global_scale handling in Python API looks sound.

The reshaping to 2D, conditional allocation for 2D/3D and swizzled/unswizzled layouts, and default global_scale=torch.ones(1, ...) all line up with the kernel’s expectations and the new tests (auto-allocation, swizzled, NVFP4/MXFP4, global_scale consistency).

The only micro-optimization you might consider is avoiding redundant .contiguous() calls on already-contiguous views, but that’s optional and not performance-critical.

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (1)

256-365: Narrow broad exception handling in benchmark loop.

The try/except Exception as e: blocks around fused and unfused timing will also swallow unexpected programming errors or misconfigurations, which can hide real issues.

Consider catching more specific exceptions (e.g., RuntimeError, torch.cuda.OutOfMemoryError) and either re-raising others or at least logging them distinctly.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (2)

42-147: Helper semantics are correct but duplicated across test modules.

dequantize_fp4_output, compute_global_scale, and assert_close_with_tiered_tolerance correctly model:

  • UE8M0 vs E4M3 scale decoding,
  • global_scale being baked into block scales and then undone by dividing at dequant,
  • and a two-tier tolerance regime appropriate for FP4 noise.

The same patterns appear in the RMSNorm FP4 test file; consider factoring these into a shared helper module (e.g., tests/test_helpers/fp4_quantization.py) to avoid divergence over time.


666-727: Tight global_scale consistency check is good; mark unused y_fp4 variables.

The test correctly asserts that block_scale with global_scale is ~global_scale times larger than without, in line with the formula scale = global_scale * max_abs / FP4_MAX.

Since y_fp4_gs and y_fp4_no_gs are unused, consider renaming them to _y_fp4_gs and _y_fp4_no_gs (or unpacking only the second element) to satisfy linters without changing behavior.

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (1)

327-355: Consider narrowing broad exception handling in benchmark loop.

As in the add+rmsnorm benchmark, the bare except Exception blocks can hide unexpected programming errors.

It would be safer to catch expected runtime issues explicitly (e.g., CUDA OOM) and optionally re-raise others, instead of treating all failures as a generic “FUSED/UNFUSED ERROR”.

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)

90-152: Global-scale computation and tiered tolerance helper are well-designed (but duplicated).

compute_global_scale and assert_close_with_tiered_tolerance mirror the helpers in the add+RMS tests and are appropriate for:

  • Choosing a global_scale that fits RMSNorm outputs into FP4 dynamic range.
  • Evaluating FP4 dequant results with a two-tier tolerance.

As noted in the add+RMS test file, consider extracting these shared helpers into a common test utility to reduce duplication.


711-768: Global-scale value consistency test is correct; mark unused y_fp4 variables.

The test correctly checks that block_scale with global_scale is ~global_scale times block_scale without, in accordance with the formula used inside the kernel.

Since y_fp4_gs and y_fp4_no_gs are not used, consider renaming them to _y_fp4_gs and _y_fp4_no_gs (or unpacking only the second return) to satisfy linters.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ba2ead and c660708.

📒 Files selected for processing (6)
  • benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
  • benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (4)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
flashinfer/norm.py (1)
  • rmsnorm (33-68)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1484-1631)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
  • kernel (1096-2173)
  • get_cute_pointers (2193-2239)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
  • kernel (1091-1728)
  • get_cute_pointers (1751-1792)
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
  • unfused_operation (143-155)
  • sanity_check_outputs (170-253)
  • compute_bandwidth_gb_s (37-73)
flashinfer/norm.py (1)
  • rmsnorm (33-68)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1484-1631)
🪛 Ruff (0.14.10)
tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py

74-74: Avoid specifying long messages outside the exception class

(TRY003)


143-146: Avoid specifying long messages outside the exception class

(TRY003)


690-690: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


703-703: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py

343-343: Do not catch blind exception: Exception

(BLE001)

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py

78-78: Avoid specifying long messages outside the exception class

(TRY003)


149-152: Avoid specifying long messages outside the exception class

(TRY003)


733-733: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


745-745: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py

337-337: Do not catch blind exception: Exception

(BLE001)


350-350: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (21)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)

1010-1010: LGTM: Global scale parameter integration.

The addition of global_scale_ptr to the kernel signature and the creation of mGlobalScale tensor are correctly implemented. The parameter is properly threaded through the host and device functions.

Also applies to: 1063-1067, 1080-1080, 1097-1097


1768-1791: LGTM: Global scale pointer creation.

The pointer creation logic correctly handles both compilation (dummy pointer) and runtime (actual tensor) paths for global_scale. The alignment of 4 bytes for Float32 is appropriate.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2218-2239: Global scale pointer wiring looks consistent with kernel expectations.

Adding global_scale as a Float32 GMEM pointer in get_cute_pointers and threading it into cute.compile and the compiled closure matches the new kernel signature. The dummy pointer path and assumed_align=4 are appropriate for the scalar.

benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (4)

76-123: Fused benchmark wiring with global_scale is correct.

Threading global_scale through to add_rmsnorm_fp4quant while keeping block_scale dtype consistent with block_size (E4M3 vs UE8M0) matches the new API and tests. Median-of-bench_gpu_time remains the right aggregation.


125-168: Unfused path correctly mirrors fused math and global_scale use.

The unfused torch.add + rmsnorm + fp4_quantize sequence, including global_scale for NVFP4 and sf_use_ue8m0=(block_size == 32), is aligned with how the fused kernel is exercised in tests. This looks like a valid baseline for speedup comparison.


170-253: Sanity-check path is consistent with fused/unfused semantics under global_scale.

Using the same global_scale for both fused and separate paths and comparing FP4 bytes with a relaxed percentage threshold is a good practical validation of correctness given FP4 noise and different operation ordering.


371-381: Geomean speedup reporting against unfused baseline looks good.

Collecting non-None speedups and reporting the geometric mean vs “unfused add + rmsnorm + fp4_quantize” matches the new fused-vs-unfused framing.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (3)

219-231: Tiered tolerance usage in value-level comparisons is appropriate.

Using assert_close_with_tiered_tolerance with tighter (rtol=0.3, atol=0.5) and looser (rtol=0.5, atol=2.0) thresholds for FP4 dequant results is a reasonable balance between strictness and the coarseness of 4‑bit quantization.


466-575: Fused vs separate NVFP4/MXFP4 comparisons robustly validate global_scale handling.

The new TestFusedVsSeparateFP4Quantize tests:

  • Compare packed FP4 bytes and block scales between fused and separate paths.
  • Use dequantization with optional global_scale and tiered tolerances.

This is exactly what’s needed to ensure the fused kernel applies global_scale identically to fp4_quantize for both NVFP4 (E4M3) and MXFP4 (UE8M0).


1034-1260: Auto-allocation tests thoroughly exercise new API shapes and layouts.

The TestAutoAllocation class verifies:

  • 2D/3D NVFP4, MXFP4, and swizzled layouts.
  • Correct shapes and dtypes of auto-allocated y_fp4 and block_scale.
  • Equality between preallocated and auto-allocated results.

These provide strong coverage for the new (y_fp4, block_scale) return semantics in add_rmsnorm_fp4quant.

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (5)

87-133: CuTe-DSL fused benchmark correctly passes global_scale to rmsnorm_fp4quant.

The extended bench_cute_dsl signature and lambda correctly propagate global_scale to the fused RMSNorm+FP4 kernel, while choosing block_scale dtype and scale_format based on block_size. This aligns with the updated API and tests.


135-177: Unfused RMSNorm + fp4_quantize path matches fused semantics.

The unfused_operation (RMSNorm followed by fp4_quantize) and its use of global_scale for NVFP4, sf_use_ue8m0 for MXFP4, and bench_gpu_time over the combined op provide a solid baseline for fused speedups.


179-260: Sanity-check compares fused vs separate outputs under shared global_scale.

Running both rmsnorm_fp4quant and separate rmsnorm + fp4_quantize with the same global_scale and checking FP4 match percentage is a practical validation of the fused kernel’s correctness given FP4’s low precision.


278-287: Global_scale construction in benchmark is reasonable and consistent.

Using FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / 3.0 as a fixed calibration-like global_scale value is consistent with the test helper’s formulation and sufficient for benchmarking.


376-383: Geomean speedup vs unfused RMSNorm+fp4_quantize is reported correctly.

Collecting finite speedups and printing the geometric mean relative to the unfused path reflects the intended performance comparison after the refactor.

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (6)

42-88: Dequantization and global_scale handling align with fused kernel math.

dequantize_fp4_output correctly:

  • Handles 2D and 3D FP4 layouts,
  • Decodes E4M3 vs UE8M0 scales,
  • Divides by global_scale when provided, consistent with block scales that include a global_scale factor.

This matches the quantization math used in rmsnorm_fp4quant and the NVFP4 tests.


199-258: NVFP4 2D/3D tests properly exercise global_scale integration.

The 2D and 3D NVFP4 tests:

  • Compute global_scale from reference RMSNorm outputs,
  • Run rmsnorm_fp4quant with that scale,
  • Dequantize with dequantize_fp4_output(..., global_scale) and compare using tiered tolerances.

This gives strong end-to-end validation of the new global_scale behavior.


444-510: Fused vs separate RMSNorm comparison correctly uses dequant + tiered tolerances.

Comparing dequantized fused outputs (with global_scale) to separate rmsnorm results using assert_close_with_tiered_tolerance is the right level of check, avoiding dependence on internal scale computation details while still validating functional correctness.


512-621: Fused vs separate fp4_quantize tests robustly validate global_scale consistency.

The NVFP4 and MXFP4 tests under TestFusedVsSeparateFP4Quantize:

  • Compare packed FP4 bytes and block-scale tensors between fused and standalone fp4_quantize.
  • Also compare dequantized outputs with tiered tolerances.

These should catch any divergence in how global_scale is applied across paths.


781-837: Large hidden NVFP4 test properly exercises cluster sync path with global_scale.

Sampling a subset of rows, computing global_scale, and validating dequantized outputs against reference RMSNorm helps ensure the cluster-synchronized path behaves correctly under the new scaling scheme.


1075-1321: Auto-allocation tests for RMSNorm-only path mirror add+RMS semantics and look solid.

TestAutoAllocation covers:

  • 2D/3D NVFP4 with global_scale,
  • MXFP4 without global_scale,
  • Swizzled layouts,
  • Parity between preallocated and auto-allocated outputs.

These tests effectively validate the new “optional outputs returning (y_fp4, block_scale)” behavior in rmsnorm_fp4quant.

Comment on lines +1839 to +1841
y_fp4: torch.Tensor | None = None,
block_scale: torch.Tensor | None = None,
global_scale: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Get the function signature and context around the mentioned lines
sed -n '1835,1850p' flashinfer/cute_dsl/rmsnorm_fp4quant.py

Repository: flashinfer-ai/flashinfer

Length of output: 609


🏁 Script executed:

# Check the return statements and key implementation lines
sed -n '1902,1915p' flashinfer/cute_dsl/rmsnorm_fp4quant.py

Repository: flashinfer-ai/flashinfer

Length of output: 593


🏁 Script executed:

# Check line 2013 (return statement)
sed -n '2010,2020p' flashinfer/cute_dsl/rmsnorm_fp4quant.py

Repository: flashinfer-ai/flashinfer

Length of output: 208


🏁 Script executed:

# Search for existing calls to rmsnorm_fp4quant in the codebase
rg -nP --type=py 'rmsnorm_fp4quant\s*\(' -B1 -A3

Repository: flashinfer-ai/flashinfer

Length of output: 19556


🏁 Script executed:

# Check if function has decorators (like @functools.cache or @flashinfer_api)
sed -n '1830,1845p' flashinfer/cute_dsl/rmsnorm_fp4quant.py

Repository: flashinfer-ai/flashinfer

Length of output: 448


Add @functools.cache decorator for module-level caching per coding guidelines.

The function is decorated with @flashinfer_api as required, but lacks the @functools.cache decorator specified in the coding guidelines for flashinfer/**/*.py files. This should be added to implement module-level caching and avoid recompilation.

Additionally, the API signature has been significantly changed:

  • y_fp4 and block_scale are now optional (previously required)
  • global_scale parameter added
  • Function now returns Tuple[torch.Tensor, torch.Tensor] instead of None

These changes are intentional and documented in the docstring. Existing code passing pre-allocated tensors will continue to work (Python allows ignoring return values), but usage patterns differ from the previous in-place operation model.

🤖 Prompt for AI Agents
In flashinfer/cute_dsl/rmsnorm_fp4quant.py around lines 1839-1841, the function
is missing the required module-level caching decorator; add @functools.cache
immediately above the existing @flashinfer_api decorator to enable module-level
caching and prevent recompilation, and ensure functools.cache is imported at the
module level (add "import functools" if not present); do not change the function
signature or behavior beyond adding the decorator.

Comment on lines +1878 to +1883
global_scale : torch.Tensor, optional
Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
If provided, the RMSNorm output is divided by this value before quantization:
``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where
a pre-computed global scale lifts per-block scales into optimal dynamic range.
If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Clarify global_scale documentation for MXFP4 format.

The docstring states that global_scale is "used for NVFP4 format" but doesn't explicitly mention that it's ignored for MXFP4 (block_size=32 with UE8M0 scale format). This could lead to confusion when users provide global_scale with block_size=32.

Recommend adding a note in the documentation:

🔎 Proposed documentation update
     global_scale : torch.Tensor, optional
         Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
         If provided, the RMSNorm output is divided by this value before quantization:
         ``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where
         a pre-computed global scale lifts per-block scales into optimal dynamic range.
+        **Note**: This parameter is only applicable for NVFP4 (block_size=16 with E4M3
+        scale format). It is ignored for MXFP4 (block_size=32 with UE8M0 scale format).
         If ``None``, no global scaling is applied (equivalent to global_scale=1.0).

This relates to the earlier major issue about validating or supporting global_scale for UE8M0.

Also applies to: 1910-1916

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Dec 30, 2025

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !213 has been updated with latest changes, and the CI pipeline #40972173 is currently running. I'll report back once the pipeline job completes.

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, thanks for the great work @bkryu !

Tuple[torch.Tensor, torch.Tensor]
A tuple of ``(y_fp4, block_scale)``:

- ``y_fp4``: Quantized FP4 values packed as uint8.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use float4_e2m1fn_x2 instead for torch 2.8+?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't realize torch.float4_e2m1fn_x2 was available; thanks for pointing this out. Changed the output format (and unit tests accordingly) in the latest commits

@cute.jit
def __call__(
self,
x_ptr: cute.Pointer,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With tvm-ffi enabled (https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html), we can pass cute.Tensor directly instead of cute.Pointer without overhead, I'll create a refactor PR later.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Dec 31, 2025

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !213 has been updated with latest changes, and the CI pipeline #41015750 is currently running. I'll report back once the pipeline job completes.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)

1003-1014: global_scale kernel docstring does not match implemented behavior

The device kernel docstring and __call__ comment state that the kernel “computes 1/global_scale” and applies y = rmsnorm(x, w) / global_scale. In the E4M3 path, the implementation instead:

  • Computes scale_float = global_scale_val * max_abs / FP4_MAX,
  • Stores this (quantized) as the per-block scale, and
  • Uses inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) * global_scale_val ≈ FP4_MAX / max_abs, which is independent of global_scale_val.

This means:

  • The FP4 codes (q) are effectively the same as in the no-global-scale case.
  • The stored block scales are multiplied by global_scale, so dequantization q * block_scale yields outputs ≈ rmsnorm(x, w) * global_scale (not divided).

For UE8M0, global_scale is ignored entirely, as expected.

The docs should be updated to describe the actual behavior: global_scale is folded into the stored E4M3 block scales, leaving quantization codes unchanged while scaling dequantized outputs proportionally. The “1/global_scale” and “/ global_scale” language is misleading given the current math.

Also applies to: 1063-1068, 1090-1108, 1252-1255, 1387-1391, 1396-1401, 1635-1655


1736-1792: Validate global_scale tensor device/dtype/shape in the Python API

The pointer wiring and auto-allocation logic (y_fp4 and block_scale for 2D/3D, swizzled/unswizzled) look solid, and _get_compiled_kernel is correctly cached.

However, rmsnorm_fp4quant assumes:

  • global_scale is on the same CUDA device as input,
  • Has shape (1,), and
  • Has dtype torch.float32,

but doesn’t enforce any of these before passing its data_ptr() to the kernel as a Float32 gmem pointer. If a caller accidentally passes a CPU tensor, different shape, or different dtype, this will yield undefined behavior at the CUDA level rather than a clear Python-side error.

Consider adding cheap upfront validation, e.g.:

  • assert global_scale.device == input.device
  • assert global_scale.dtype == torch.float32
  • assert global_scale.numel() == 1

(or corresponding ValueErrors) before the kernel launch.

Also applies to: 1814-1832, 1939-1997, 2007-2017

♻️ Duplicate comments (3)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)

1835-1846: Clarify global_scale API doc: effect and MXFP4 applicability

The rmsnorm_fp4quant docstring currently says:

  • The RMSNorm output is “divided by” global_scale before quantization (y = rmsnorm(x, w) / global_scale), and
  • Mentions NVFP4 usage but not that global_scale is ignored for MXFP4/UE8M0.

Given the kernel implementation:

  • For E4M3 (NVFP4), global_scale is incorporated into the stored block scales, leading to dequantized outputs proportional to global_scale (while the FP4 codes stay effectively unchanged).
  • For UE8M0 (MXFP4), global_scale is not used at all.

The parameter docs and “Returns” section should be adjusted to:

  • Describe that global_scale is folded into E4M3 block scales and affects the magnitude of dequantized outputs, not that the RMSNorm output is divided by it pre-quantization.
  • Explicitly state that global_scale is only applicable for NVFP4 (block_size=16, E4M3) and is ignored for MXFP4 (block_size=32, UE8M0).

This aligns the public API contract with the actual kernel math and avoids confusion for users calibrating global scales.

Also applies to: 1879-1883, 1902-1909

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

1008-1027: Kernel/global_scale docstring still contradicts actual quantization behavior

The AddRMSNormFP4QuantKernel kernel and host __call__ docstrings describe:

  • Computing 1/global_scale, multiplying it into rstd, and
  • Effectively applying y = rmsnorm(h, w) / global_scale.

As in the standalone RMSNorm kernel, the implementation instead:

  • Incorporates global_scale into scale_float = global_scale_val * max_abs * fp4_max_rcp for E4M3 (NVFP4),
  • Uses inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) * global_scale_val, which cancels global_scale so FP4 codes are unchanged relative to the no-global-scale case, and
  • Leaves UE8M0 (MXFP4) ignoring global_scale entirely.

Net effect: dequantized outputs (q * block_scale) are scaled by global_scale for E4M3, not divided by it.

Please update these kernel-level docstrings to reflect:

  • That global_scale is folded into E4M3 block scales (and ignored for UE8M0),
  • That it scales the dequantized result rather than altering rstd directly.

Also applies to: 1068-1072, 1095-1114, 1274-1277, 1369-1382, 1511-1563, 1752-1772, 2076-2096


2281-2335: add_rmsnorm_fp4quant: align global_scale API doc with behavior and validate inputs

The high-level API changes are generally good:

  • Optional y_fp4/block_scale with correct auto-allocation for 2D/3D and swizzled/unswizzled layouts.
  • Return of (y_fp4, block_scale) is consistent with the new fused API design.
  • Use of torch.float4_e2m1fn_x2 and appropriate scale dtypes matches the kernels.

Two issues to address:

  1. Docstring semantics and MXFP4 applicability

    The global_scale parameter doc currently claims:

    • The RMSNorm output is divided by global_scale before quantization (y = rmsnorm(h, w) / global_scale), and
    • Does not state that MXFP4 ignores global_scale.

    In reality (E4M3/NVFP4):

    • global_scale is baked into the block scales; FP4 codes are unchanged, and dequantized outputs scale with global_scale.
    • For UE8M0/MXFP4, global_scale is not used.

    The parameter description should be updated accordingly, and explicitly note that global_scale is only meaningful for NVFP4/E4M3 and ignored for MXFP4/UE8M0.

  2. Runtime validation of global_scale tensor

    As with rmsnorm_fp4quant, the function assumes global_scale is a 1-element torch.float32 tensor on the same CUDA device as input, but does not check:

    • Device equality,
    • Dtype (torch.float32), or
    • Shape/numel (1).

    Passing a CPU tensor or wrong dtype would result in the kernel reading an invalid device pointer. Adding simple validation (or coercing to the correct device/dtype with a small copy) before invoking tensor_api would make this API much safer.

Also applies to: 2353-2367, 2368-2469

🧹 Nitpick comments (10)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (2)

127-170: Unfused benchmark behavior vs MXFP4/global_scale contract

The combined unfused path (add → RMSNorm → fp4_quantize) is structured correctly and includes global_scale for NVFP4. However, the helper and PR description state that global_scale should not be used for MXFP4, while this function would still forward a non-None global_scale if called with block_size=32.

Consider explicitly gating this:

  • Only pass global_scale to fp4_quantize when block_size == 16, or
  • Force global_scale=None when block_size == 32.

This keeps the helper aligned with the documented MXFP4 behavior and avoids surprises if it’s reused with block_size=32.


261-387: run_benchmark orchestration and reporting look good; consider refining exception handling

The benchmark wiring for:

  • Fixed NVFP4 block_size=16 and calibrated global_scale,
  • Fused vs unfused timing and bandwidth computation, and
  • Geomean speedup vs the unfused path

is coherent and matches the rmsnorm-only benchmark style.

The broad except Exception as e/except Exception blocks are acceptable for a CLI benchmark, but they do trip Ruff’s BLE001 and can hide unexpected errors.

If you want to align with the linter while keeping robustness, consider:

  • Catching a narrower set (e.g., RuntimeError) or
  • At least logging the exception type/message more prominently so unexpected failures are obvious.
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (2)

137-179: Unfused RMSNorm+FP4 helper: clarify MXFP4/global_scale behavior

bench_separate_flashinfer correctly sequences RMSNorm then fp4_quantize and times the combined operation. The doc comment says that for MXFP4 (block_size=32) global_scale is not used, but the function still forwards the global_scale argument to fp4_quantize unconditionally.

To keep behavior and docs in sync, consider:

  • Only passing global_scale when block_size == 16 (NVFP4), or
  • Passing global_scale=None in the MXFP4 branch.

This also matches the PR’s guidance that global_scale should not be provided for MXFP4.


280-391: Benchmark harness updates and speedup reporting are coherent; optional refinement to exception handling

The changes to:

  • Fix block_size=16 (NVFP4),
  • Introduce a calibrated global_scale for benchmarking,
  • Report fused time, bandwidth, unfused time, and speedup, and
  • Compute geomean speedup vs the unfused path

are internally consistent and mirror the Add+RMSNorm benchmark.

As in the other benchmark file, the bare except Exception handlers flagged by Ruff are acceptable for a benchmarking script but can obscure unexpected failures. Narrowing the exception type or improving the logged diagnostics would be a low-cost improvement.

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (4)

42-83: Helper functions for dequantization, global_scale, and tiered tolerance look correct

  • dequantize_fp4_output correctly:

    • Interprets torch.float4_e2m1fn_x2 as uint8 for cast_from_fp4,
    • Applies per-block scales for both E4M3 (float8_e4m3fn) and UE8M0 (uint8 via 2^(ue8m0-127)), and
    • Optionally divides by global_scale to conceptually undo the fused scaling.
  • compute_global_scale matches the benchmark-style formula and uses a reference LLaMA RMSNorm; for test-only use this is fine, though guarding against tensor_amax == 0 would make it more robust.

  • assert_close_with_tiered_tolerance is a good fit for low-precision FP4 comparisons, capturing both “most values tight” and “all values bounded” constraints.

Overall, these helpers provide a solid foundation for the new global_scale-aware tests.

Also applies to: 85-108, 110-148


25-29: Use shared flashinfer.utils helpers for GPU capability checks

The tests currently use a local get_cc() wrapper over torch.cuda.get_device_capability() and custom skip conditions (requires_blackwell).

Per the testing guidelines, these checks should ideally go through flashinfer.utils helpers such as get_compute_capability / is_sm100a_supported to keep skip logic consistent across the suite.

Consider refactoring requires_blackwell() (and any direct CC checks) to delegate to the shared utilities instead of duplicating capability logic here.

Also applies to: 156-168


404-469: Fused vs separate FP4Quant tests are well-designed; minor cleanup around global_scale usage

The new TestFusedVsSeparateFP4Quantize tests:

  • Compare fused Add+RMSNorm+FP4Quant against add + RMSNorm + fp4_quantize for both NVFP4 (block_size=16, E4M3) and MXFP4 (block_size=32, UE8M0).
  • Check:
    • Packed FP4 bytes (view(torch.uint8)),
    • Block scale factors, and
    • Dequantized values via dequantize_fp4_output and the tiered tolerance helper.

This is a strong end-to-end validation that the fused kernels match the standalone fp4_quantize implementation, including global_scale behavior.

One small point: in the MXFP4 test you pass global_scale_val = torch.tensor(1.0, ...) positionally into fp4_quantize even though MXFP4 conceptually doesn’t use global_scale. Keeping this at 1.0 is harmless, but if the underlying API ever tightens its contract for UE8M0, it may be safer to pass global_scale=None explicitly in that branch.

Also applies to: 472-587, 589-683


685-745: Unused y_fp4_gs / y_fp4_no_gs in global_scale consistency test

In test_global_scale_value_consistency, the unpacked variables:

  • y_fp4_gs (line 708),
  • y_fp4_no_gs (line 721),

are never used; only the corresponding block scales are consumed.

To satisfy the linter and clarify intent, you can either:

  • Prefix them with an underscore (_y_fp4_gs, _y_fp4_no_gs), or
  • Assign to _ if you don’t plan to use the outputs.
tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (2)

25-28: Consider using flashinfer.utils functions for GPU capability checks.

The custom get_cc() function works but doesn't follow the coding guidelines. Per the guidelines and learnings, test implementations should use flashinfer.utils functions like get_compute_capability(), is_sm90a_supported(), or is_sm100a_supported() to skip tests on unsupported GPU architectures.

🔎 Suggested refactor
+from flashinfer.utils import get_compute_capability
+
-def get_cc():
-    """Get CUDA compute capability."""
-    major, minor = torch.cuda.get_device_capability()
-    return major * 10 + minor

Then update usages:

 def requires_hopper_or_later():
     """Check if running on Hopper (SM90+) or later GPU."""
-    return get_cc() >= 90
+    return get_compute_capability() >= 90

 def requires_blackwell():
     """Check if running on Blackwell GPU."""
-    return get_cc() >= 100
+    return get_compute_capability() >= 100

Based on coding guidelines: Test implementations should use flashinfer.utils functions for GPU capability checks.


751-751: Use underscore prefix for intentionally unused variables.

Lines 751 and 763 unpack return values but only use the block_scale variables. Per Python convention, use underscore prefix for intentionally unused variables to improve clarity and silence linter warnings.

🔎 Proposed fix
-        y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
+        _y_fp4_gs, block_scale_gs = rmsnorm_fp4quant(
             x,
             weight,
             global_scale=global_scale,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )

-        y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
+        _y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant(
             x,
             weight,
             global_scale=global_scale_one,
             eps=eps,
             block_size=block_size,
             is_sf_swizzled_layout=False,
         )

Also applies to: 763-763

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c660708 and 532c80a.

📒 Files selected for processing (6)
  • benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py
  • benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧠 Learnings (4)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation

Applied to files:

  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Applied to files:

  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
  • tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
🧬 Code graph analysis (2)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (3)
benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (3)
  • unfused_operation (156-166)
  • sanity_check_outputs (181-265)
  • run_benchmark (268-392)
flashinfer/norm.py (1)
  • rmsnorm (33-68)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1484-1631)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (6)
  • kernel (1091-1728)
  • fmin_f32 (208-220)
  • cvt_f32_to_e4m3 (462-482)
  • fp8_e4m3_to_f32_and_rcp (486-518)
  • ue8m0_to_output_scale (574-606)
  • get_cute_pointers (1751-1792)
flashinfer/cute_dsl/utils.py (1)
  • make_ptr (175-223)
🪛 Ruff (0.14.10)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py

348-348: Do not catch blind exception: Exception

(BLE001)

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py

76-76: Avoid specifying long messages outside the exception class

(TRY003)


145-148: Avoid specifying long messages outside the exception class

(TRY003)


708-708: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


721-721: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py

342-342: Do not catch blind exception: Exception

(BLE001)


355-355: Do not catch blind exception: Exception

(BLE001)

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py

80-80: Avoid specifying long messages outside the exception class

(TRY003)


151-154: Avoid specifying long messages outside the exception class

(TRY003)


751-751: Unpacked variable y_fp4_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


763-763: Unpacked variable y_fp4_no_gs is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (14)
benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py (2)

76-89: Fused CuTe-DSL benchmark: global_scale threading and new FP4 dtype look correct

The added global_scale parameter is cleanly threaded into add_rmsnorm_fp4quant, and switching y_fp4 to torch.float4_e2m1fn_x2 matches the fused kernel’s new output dtype. The allocation shapes and use of bench_gpu_time remain consistent with the bandwidth model (1 byte per packed FP4 pair).

Also applies to: 105-122


172-257: Sanity check for fused vs separate path with global_scale is well-structured

The sanity check correctly:

  • Uses the new torch.float4_e2m1fn_x2 dtype for FP4 outputs.
  • Propagates global_scale through both fused and separate paths.
  • Compares packed FP4 bytes via .view(torch.uint8) with a reasonable ≥70% match threshold.

This should be a solid guard against regressions in the fused kernel’s global_scale handling.

benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py (2)

87-97: Fused RMSNorm benchmark: global_scale and FP4 dtype integration LGTM

Adding global_scale to bench_cute_dsl and allocating y_fp4 as torch.float4_e2m1fn_x2 matches the fused kernel’s API. The block-scale allocations and scale_format selection stay consistent with NVFP4 (E4M3) vs MXFP4 (UE8M0).

Also applies to: 114-125


181-265: Sanity check with global_scale and new FP4 dtype is sound

The updated sanity_check_outputs:

  • Uses torch.float4_e2m1fn_x2 and the updated fused API,
  • Propagates global_scale through fused and separate paths, and
  • Compares packed FP4 results via .view(torch.uint8) with a ≥70% match threshold.

This is a reasonable and robust check for fused-vs-separate behavior under global scaling.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2181-2215: Pointer wiring for global_scale looks correct

The additions to _get_compiled_kernel and tensor_api to handle a global_scale pointer are consistent with the tensor layout used in the kernel:

  • Dummy and real pointer lists both include a final cutlass.Float32 gmem pointer for the scalar global scale.
  • tensor_api always receives a global_scale tensor and passes it into the compiled kernel invocation.

This matches the kernel’s new parameter list without changing the call sites’ responsibilities.

Also applies to: 2216-2239, 2259-2277

tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py (4)

197-215: Core 2D/3D/NVFP4/MXFP4 tests updated to new FP4 dtype and look solid

Across the main test classes:

  • All y_fp4 tensors now use torch.float4_e2m1fn_x2 with the expected (batch, hidden_size // 2) or (batch, seq_len, hidden_size // 2) shapes.
  • block_scale dtypes and shapes are consistent:
    • E4M3/NVFP4: torch.float8_e4m3fn with (batch, ..., hidden_size // block_size),
    • UE8M0/MXFP4: torch.uint8 with matching shapes.
  • Dequantization checks use either plain torch.testing.assert_close or the tiered helper, with tolerances aligned to FP4 precision and MXFP4’s extra quantization error.

These baseline correctness tests align well with the new kernel outputs and the rest of the PR.

Also applies to: 252-258, 315-323, 364-371, 775-795, 828-833


758-866: Large hidden-size tests with new FP4 dtype remain consistent

The large hidden-size NVFP4/MXFP4 tests:

  • Correctly use torch.float4_e2m1fn_x2 for y_fp4,
  • Preserve expected block_scale dtypes (E4M3 vs UE8M0),
  • Only sample a subset of rows for dequantization to keep runtime manageable.

Given the problem sizes, this strikes a good balance between coverage and test cost.


905-983: Swizzled vs unswizzled tests adapt cleanly to float4_e2m1fn_x2

The swizzled scale-factor tests:

  • Allocate both reference and swizzled y_fp4 as torch.float4_e2m1fn_x2,
  • Compare FP4 outputs via .view(torch.uint8), and
  • Use unswizzle_sf to bring swizzled scales back to row-major for equality checks.

This is a thorough check that the new swizzled layout still matches the unswizzled baseline under the updated dtype.

Also applies to: 985-1053


1056-1284: Auto-allocation tests comprehensively cover NVFP4/MXFP4 and swizzled layouts

The TestAutoAllocation class:

  • Verifies that omitting y_fp4 and block_scale returns correctly-shaped and correctly-typed tensors for:
    • 2D/3D NVFP4,
    • MXFP4, and
    • NVFP4 with swizzled scale layout.
  • Confirms numerical correctness against the LLaMA RMSNorm reference and equality vs preallocated outputs (bitwise via .view(torch.uint8)).

This is excellent coverage of the new allocation semantics and should catch most regressions in the Python wrapper behavior.

tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py (5)

42-89: LGTM!

The global_scale integration in dequantization is correct. The function properly reverses the scaling applied during quantization by dividing the result by global_scale.item() when provided.


92-113: LGTM!

The global_scale computation correctly implements the formula to ensure the dynamic range fits within FP4. Constants and device placement are appropriate.


116-154: LGTM!

The two-tiered tolerance check is well-designed for quantized outputs. The detailed error messages are valuable for debugging quantization mismatches, despite the static analysis warning about message length.


518-787: LGTM!

The new test classes provide excellent coverage:

  • TestFusedVsSeparateFP4Quantize validates consistency between fused and separate quantization paths for both NVFP4 and MXFP4
  • test_global_scale_value_consistency verifies that global_scale correctly scales the block scales
  • Tests use appropriate tolerance checks and cover multiple parameter combinations

1097-1346: LGTM!

The TestAutoAllocation class provides comprehensive coverage of the auto-allocation feature:

  • Tests both 2D and 3D inputs with NVFP4 (including global_scale)
  • Tests MXFP4 format with UE8M0 scales
  • Tests swizzled layout auto-allocation
  • Verifies auto-allocated results match pre-allocated results
  • Proper shape, dtype, and value assertions throughout

@yzh119 yzh119 merged commit 6f1624c into flashinfer-ai:main Jan 1, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants