[Feat] Add RMSNorm NvFp4 Quant Operator (#32612) #32957

sparkecho wants to merge 14 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a new fused RMSNorm + NVFP4 quantization operator, along with its CUDA kernel implementation, C++ bindings, Python integration for fusion, and a comprehensive test. The changes are well-structured and follow established patterns within the codebase. The use of TORCH_CHECK for input validation in the kernel and conditional compilation for different SM architectures are good practices. The integration into the fusion pass ensures that this optimized operator can be leveraged where applicable.
Hi @sparkecho, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits, the installed hooks will run the checks automatically.
ProExpertProg left a comment
Looks great overall!
Could you report E2E speedup & accuracy numbers?
```
@@ -269,4 +269,54 @@ __inline__ __device__ PackedVec<Type> compute_silu_mul(
  return result;
}

// Compute sum of squares for a PackedVec (8 elements).
```
These aren't actually fp4 utils, could you move them to layernorm utils?
Ok, I'll move them to layernorm utils.
Hi Luka, I encountered some issues while moving the two new functions I added, compute_packed_sum_squares and compute_rms_norm, from nvfp4_utils.cuh to layernorm_utils.cuh.
Since these functions rely on PackedVec (which is defined in nvfp4_utils.cuh), this creates a dependency where fused_kernels/layernorm_utils.cuh must include fp4/nvfp4_utils.cuh. Meanwhile, fp4/rmsnorm_nvfp4_quant_kernels.cu needs to include fused_kernels/layernorm_utils.cuh. This inclusion chain feels awkward and potentially circular.
I am considering two possible solutions:
Option 1: Place those two functions directly inside fp4/rmsnorm_nvfp4_quant_kernels.cu.
Option 2: Create a new file, fp4/rmsnorm_utils.cuh, and move the functions there.
The ideal solution might be to reorganize the directory structure entirely. From a functional standpoint, compute_packed_sum_squares and compute_rms_norm are indeed very similar to the functions currently in fused_kernels/layernorm_utils.cuh.
What would you recommend in this case?
Ah yeah, this part of the code doesn't have the best structure. Could you extract the PackedVec util into its own file? And then put the rmsnorm utils next to the other layer norm util functions.
(Perhaps in a follow-up) it would be good to see if your layernorm functions outperform the existing ones and if we could use yours in other kernels as well.
Done. Also fixed hardcoded bfloat16 instances (flagged by Cursor). Could you please take another look at this?
I noticed that PR #32520 merged some FP4 operator optimizations into the main branch, which will cause conflicts with my current code. I'd like to run a benchmark first before rebasing onto main.
Also, please fix pre-commit and DCO.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.
Comment @cursor review or bugbot run to trigger another review on this PR
This pull request has merge conflicts that must be resolved before it can be merged.
```
$ python tests/evals/gsm8k/gsm8k_eval.py --port 8000
```

PR results:

Main results:
When you run with -rms_norm, you need to manually enable the fusion; can you try that? Please run the following cases E2E, all on your PR:
Benchmark results indicate that -rms_norm yields better performance than +rms_norm. Also, the performance gain from fusion is trivial when compared to the no-fusion case.

Case 1: no fusion, -rms_norm
Case 2: no fusion, +rms_norm
Case 3: fusion, -rms_norm
Case 4: fusion, +rms_norm
I don't know that the fusion benefit is just trivial; it looks decent to me! Could you collect lm_eval numbers, as well as vllm bench latency for the fusion vs. no-fusion cases (with -rms_norm, which should be the default anyway)?
I viewed the gains as trivial because the delta for +rms_norm (910.2 → 912.5) is quite small compared to the -rms_norm case (919.9 → 927.1). Regardless, I'm happy to defer to your expertise here.
Case 1: no fusion, -rms_norm
Case 2: no fusion, +rms_norm
Case 3: fusion, -rms_norm
Case 4: fusion, +rms_norm
Not trivial: -rms_norm is better and is the default. Not sure why you focused on +rms_norm initially, so let's focus on -rms_norm moving forward.
The serving benchmarks seem off; can you rerun with a larger number of requests? And please run vllm bench latency as well for 1 output token and {512, 2048, 8192} input tokens (and batch size 1).
Got it. I'll rerun the benchmarks with a larger number of requests and collect the latency data for those specific input sizes. Will update you once it's done.
I’m a bit confused by the latest data I've gathered. Could you take a look at the commands I used to make sure the setup is correct?
no fusion:

fusion:
Yeah, not sure why this is happening. Have you been able to look at a profile to see what's happening there?
Currently, I only have access to nsys for performance profiling, as ncu permissions are restricted in my current environment. Meanwhile, the servers where I do have ncu access don't yet support FP4. I will proceed with the analysis using nsys for now, while simultaneously looking for an environment that supports both ncu and FP4.
The latest stable release of FlashInfer is v0.6.3, and this version already includes rmsnorm_fp4quant and add_rmsnorm_fp4quant. …

- rmsnorm_nvfp4_quant_kernels.cu
- nvfp4_quant_entry.cu (line 121)
- torch_bindings.cpp (line 227)

In the same FlashInfer version, the corresponding fused APIs are already available in:

- __init__.py (line 103)
- rmsnorm_fp4quant.py (line 761)
- add_rmsnorm_fp4quant.py (line 1015)
@baonudesifeizhai Thanks for pointing that out. We'll proceed with our current plan for the time being, as we want to run some benchmarks against FlashInfer before finalizing the roadmap. |
```
// SF layout pads rows to 128, so we need to process those padded rows too
int effective_rows = (num_tokens + 127) / 128 * 128;
dim3 grid(
    std::min(effective_rows, multi_processor_count * num_blocks_per_sm));
```
I can see you didn't reuse the grid layout from the optimized FP4 quant kernels. I would give it a try; it will significantly improve the load balancing and the occupancy of your kernel.
Thanks for pointing this out! I did notice that implementation, but I wasn't quite sure how to map the 2D grid layout to the RMSNorm calculation. I appreciate the guidance—I'll take another look and give it a try.
```
// First pass: compute x = input + residual, update residual, compute
// variance
float variance = 0.0f;
for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded;
```
Note that changing the grid layout will also eliminate the need for this inner loop.
You probably just need to add it to the pass in utility/fix_functionalization.py
Thanks for the pointer! I was a bit lost, but this looks like the right direction. I'll give it a try and report back.
Impressive results. After adding the rmsnorm_fp4 and add_rmsnorm_fp4 operators to the pass, I compared:

- no fusion vs. fusion (with FixFunctionalizationPass)
- no fusion vs. fusion (without FixFunctionalizationPass)

no fusion

fusion (with FixFunctionalizationPass)

fusion (without FixFunctionalizationPass)
Great results! Please report speed-up numbers again once you apply the optimization suggested by @LopezCastroRoberto.
Sounds good. I'll run benchmarks before and after applying the optimization to show the exact improvement, and will update here as soon as I have the results.
The benchmark results on the B200 still look a bit inconsistent (performance for PR-Fused is significantly worse than PR-Unfused). I plan to conduct a more thorough evaluation once the grid layout optimizations are finalized. Here is a brief summary of the current findings:
E2E Throughput (guidellm)
Serving Throughput (vllm bench serve)
@ProExpertProg @LopezCastroRoberto Apologies for the long silence. I’ve been tied up with other commitments and lacked a proper development environment, which put my optimization work on hold. Now that my environment is set up, I’d like to pick this back up. Do you still think it’s worth pursuing this feature? If we proceed, I have a technical concern: reusing the grid and block configurations from nvfp4_quant_kernels.cu implies using a 2D grid. I’ve attempted this implementation, but it resulted in degraded performance.
Yes, let's pursue this. If you prefer, we can merge the FlashInfer rms-fp4 kernel first and then tune yours until it's better, or merge this kernel first and then FlashInfer. I don't have a preference, as long as (1) we get this fusion merged (with a speedup) and (2) we are using the fastest kernel we have available.
I think the timeline for performance optimization could be quite uncertain. So I'm planning to first work on FlashInfer's kernel and run some benchmarks. If we see a clear performance improvement, I'll go ahead and merge it first, and then move on to optimizing the CUDA kernel. Do you think this is ok?
Yep, that sounds great!
We’ve been doing some refactoring to eliminate duplicated helper code for vectorized instructions. Could you please rebase your branch and use the shared helper file instead? #35105
@sparkecho any update on the flashinfer kernel integration? Feel free to open a new PR if you want.
@ProExpertProg Thanks for checking in! Progress has been a bit slower than expected on my end, but I am currently focused on the FlashInfer kernel integration. I’m making steady headway and expect to open a new PR (or submit a fresh update) within the next two days. I really appreciate your patience and the nudge!
Sounds good, looking forward to it! This is one of the few remaining obvious ones for models like DeepSeek.


Purpose
This PR implements rmsnorm + fp4 quant fusion and integrates it into the rmsnorm + quant fusion pass, fixing #32612.
This PR also includes code refactoring for better modularity and maintainability.
To enable the fusion, add the following compilation flags:
Performance data provided below reflects testing conducted on a B200 platform.
E2E results
Dense model
MoE Model
Accuracy test
Dense model
MoE Model
Unit tests
pytest tests/kernels/quantization/test_rmsnorm_nvfp4_quant.py -- ALL PASSED
pytest tests/kernels/quantization/test_nvfp4_quant.py -- ALL PASSED
pytest tests/compile/test_fusion.py -- ALL PASSED

Tested Platforms
The RMSNorm + FP4 quantization fusion has been validated on the following platforms: