[Feat] Add RMSNorm NvFp4 Quant Operator (#32612) #32957

Open
sparkecho wants to merge 14 commits into vllm-project:main from sparkecho:feature/rmsnorm-fp4quant

Conversation


@sparkecho sparkecho commented Jan 23, 2026

Purpose

This PR implements RMSNorm + FP4 quant fusion and integrates it into the RMSNorm + quant fusion pass, fixing #32612.
This PR also includes code refactoring for better modularity and maintainability.

To enable the fusion, add the following compilation flags:

--compilation-config '{"custom_ops":["+rms_norm"]}'
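
For context, the fusion replaces an eager RMSNorm op followed by NVFP4 block quantization with a single kernel. The reference math being fused can be sketched in plain Python (a minimal sketch assuming NVFP4's E2M1 magnitude set and block size 16; function names are illustrative, not vLLM's API):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # Reference RMSNorm: x / sqrt(mean(x^2) + eps) * weight
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def nvfp4_block_quant(x, block=16, fp4_max=6.0):
    # Per-block quantization sketch: scale each 16-element block so its
    # absmax maps onto the E2M1 max magnitude (6.0), then round each value
    # to the nearest representable E2M1 magnitude. The real format stores
    # FP8 block scales plus a global FP32 scale; floats are kept here for
    # clarity.
    levels = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes
    out, scales = [], []
    for i in range(0, len(x), block):
        blk = x[i:i + block]
        s = max(abs(v) for v in blk) / fp4_max or 1.0
        scales.append(s)
        out.extend(
            math.copysign(min(levels, key=lambda l: abs(abs(v) / s - l)), v)
            for v in blk
        )
    return out, scales
```

In the fused operator both steps run in a single pass over the row instead of materializing the normalized activations in between.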

All performance data below was collected on a B200 platform.

E2E results

Dense model

| Metric | MAIN | PR | Delta | Status |
|---|---|---|---|---|
| Requests/s (Mean) | 6.0 | 6.0 | 0.0% | ✅ Stable |
| Input Tokens/s (Mean) | 648.5 | 649.8 | +0.2% | ✅ Improved |
| Output Tokens/s (Mean) | 803.7 | 805.3 | +0.2% | ✅ Improved |
| Total Tokens/s (Mean) | 1419.7 | 1422.5 | +0.2% | ✅ Improved |
vllm serve nvidia/Llama-3.3-70B-Instruct-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"]}'
guidellm benchmark --target "http://localhost:8000/v1" --profile concurrent --rate 10 --data "prompt_tokens=64,output_tokens=128" --max-seconds 30
# MAIN

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 1.3 | 6.0  | 10.0  | 10.0 | 126.1 | 648.5 | 276.0  | 803.7 | 278.0 | 1419.7 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

# PR

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 1.3 | 6.0  | 10.0  | 10.0 | 126.3 | 649.8 | 270.3  | 805.3 | 272.6 | 1422.5 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

MoE Model

| Metric | MAIN | PR | Delta | Status |
|---|---|---|---|---|
| Requests/s (Mean) | 15.0 | 15.0 | 0.0% | ✅ Stable |
| Input Tokens/s (Mean) | 1117.8 | 1119.1 | +0.1% | ✅ Improved |
| Output Tokens/s (Mdn) | 908.1 | 1414.1 | +55.7% | 🚀 Significant |
| Output Tokens/s (Mean) | 1946.6 | 1948.5 | +0.1% | ✅ Improved |
| Total Tokens/s (Mean) | 3049.0 | 3050.8 | +0.06% | ✅ Improved |
vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"]}'
guidellm benchmark --target "http://localhost:8000/v1" --profile concurrent --rate 10 --data "prompt_tokens=64,output_tokens=128" --max-seconds 30
# MAIN

|============|=====|======|=======|======|=======|========|=======|========|=======|========|
| Benchmark  | Requests               |||| Input Tokens  || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec       || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean   | Mdn   | Mean   | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|--------|-------|--------|-------|--------|
| concurrent | 3.1 | 15.0 | 10.0  | 10.0 | 237.4 | 1117.8 | 908.1 | 1946.6 | 926.2 | 3049.0 |
|============|=====|======|=======|======|=======|========|=======|========|=======|========|

# PR

|============|=====|======|=======|======|=======|========|========|========|========|========|
| Benchmark  | Requests               |||| Input Tokens  || Output Tokens  || Total Tokens   ||
| Strategy   | Per Sec   || Concurrency || Per Sec       || Per Sec        || Per Sec        ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean   | Mdn    | Mean   | Mdn    | Mean   |
|------------|-----|------|-------|------|-------|--------|--------|--------|--------|--------|
| concurrent | 3.1 | 15.0 | 10.0  | 10.0 | 225.6 | 1119.1 | 1414.1 | 1948.5 | 1443.1 | 3050.8 |
|============|=====|======|=======|======|=======|========|========|========|========|========|

Accuracy test

Dense model

| Metric | MAIN | PR | Delta | Status |
|---|---|---|---|---|
| Accuracy | 0.930 | 0.926 | -0.4% | ✅ Acceptable |
| Invalid responses | 0.000 | 0.001 | +0.1% | ✅ Acceptable |
| Questions/s | 63.698 | 63.053 | -1.0% | ✅ Acceptable |
| Output Tokens/s | 5959.65 | 5899.81 | -1.0% | ✅ Acceptable |
# MAIN

Results:
Accuracy: 0.930
Invalid responses: 0.000
Total latency: 20.707 s
Questions per second: 63.698
Total output tokens: 123407
Output tokens per second: 5959.653

# PR

Results:
Accuracy: 0.926
Invalid responses: 0.001
Total latency: 20.919 s
Questions per second: 63.053
Total output tokens: 123417
Output tokens per second: 5899.810

MoE Model

| Metric | MAIN | PR | Delta | Status |
|---|---|---|---|---|
| Accuracy | 0.884 | 0.884 | 0.0% | ✅ Stable |
| Invalid responses | 0.001 | 0.001 | 0.0% | ✅ Stable |
| Questions/s | 116.100 | 114.990 | -1.0% | ✅ Acceptable |
| Output Tokens/s | 13637.47 | 13508.78 | -0.9% | ✅ Acceptable |
# MAIN

Results:
Accuracy: 0.884
Invalid responses: 0.001
Total latency: 11.361 s
Questions per second: 116.100
Total output tokens: 154934
Output tokens per second: 13637.468

# PR

Results:
Accuracy: 0.884
Invalid responses: 0.001
Total latency: 11.471 s
Questions per second: 114.990
Total output tokens: 154953
Output tokens per second: 13508.784

Unit tests

pytest tests/kernels/quantization/test_rmsnorm_nvfp4_quant.py -- ALL PASSED
pytest tests/kernels/quantization/test_nvfp4_quant.py -- ALL PASSED
pytest tests/compile/test_fusion.py -- ALL PASSED

Tested Platforms

The RMSNorm + FP4 quantization fusion has been validated on the following platforms:

| Platform | Compute Capability | Status |
|---|---|---|
| NVIDIA B200 | SM100 | ✅ Verified |
| NVIDIA RTX 5090 | SM120 | ✅ Verified |

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the ci/build label Jan 23, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new fused RMSNorm + NVFP4 quantization operator, along with its CUDA kernel implementation, C++ bindings, Python integration for fusion, and a comprehensive test. The changes are well-structured and follow established patterns within the codebase. The use of TORCH_CHECK for input validation in the kernel and conditional compilation for different SM architectures are good practices. The integration into the fusion pass ensures that this optimized operator can be leveraged where applicable.


mergify bot commented Jan 23, 2026

Hi @sparkecho, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@ProExpertProg ProExpertProg left a comment


Looks great overall!

Could you report E2E speedup & accuracy numbers?

@@ -269,4 +269,54 @@ __inline__ __device__ PackedVec<Type> compute_silu_mul(
return result;
}

// Compute sum of squares for a PackedVec (8 elements).
Collaborator


These aren't actually fp4 utils, could you move them to layernorm utils?
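
The utilities under discussion implement plain RMSNorm arithmetic over packed 8-element vectors. A rough Python sketch of that math (illustrative signatures, not the CUDA code; the block-wide reduction between the two steps is omitted):

```python
import math

def compute_packed_sum_squares(packed_vecs):
    # Sum of squares across the 8-element PackedVecs owned by one thread;
    # in the kernel this partial sum then feeds a block-wide reduction.
    return sum(v * v for vec in packed_vecs for v in vec)

def compute_rms_norm(vec, weight, sum_squares, hidden_size, eps=1e-6):
    # Normalize one PackedVec using the row's total sum of squares:
    # y = x / sqrt(mean(x^2) + eps) * weight
    inv_rms = 1.0 / math.sqrt(sum_squares / hidden_size + eps)
    return [v * inv_rms * w for v, w in zip(vec, weight)]
```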


@sparkecho sparkecho Jan 23, 2026


Ok, I'll move them to layernorm utils.

Author


Hi Luka, I encountered some issues while moving the two new functions I added, compute_packed_sum_squares and compute_rms_norm, from nvfp4_utils.cuh to layernorm_utils.cuh.

Since these functions rely on PackedVec (which is defined in nvfp4_utils.cuh), this creates a dependency where fused_kernels/layernorm_utils.cuh must include fp4/nvfp4_utils.cuh. Meanwhile, fp4/rmsnorm_nvfp4_quant_kernels.cu needs to include fused_kernels/layernorm_utils.cuh. This inclusion chain feels awkward and potentially circular.

I am considering two possible solutions:
Option 1: Place those two functions directly inside fp4/rmsnorm_nvfp4_quant_kernels.cu.
Option 2: Create a new file, fp4/rmsnorm_utils.cuh, and move the functions there.

The ideal solution might be to reorganize the directory structure entirely. From a functional standpoint, compute_packed_sum_squares and compute_rms_norm are indeed very similar to the functions currently in fused_kernels/layernorm_utils.cuh.

What would you recommend in this case?

Collaborator


Ah yeah, this part of the code doesn't have the best structure. Could you extract the PackedVec util into its own file? And then put the rmsnorm utils next to the other layer norm util functions.

(Perhaps in a follow-up) it would be good to see if your layernorm functions outperform the existing ones and if we could use yours in other kernels as well.

Author


Done. Also fixed hardcoded bfloat16 instances (flagged by Cursor). Could you please take another look at this?
I noticed that PR #32520 merged some FP4 operator optimizations into the main branch, which will cause conflicts with my current code. I'd like to run a benchmark first before rebasing onto main.

@ProExpertProg

Also please fix precommit and dco


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

@sparkecho

Also please fix precommit and dco
Sure.


mergify bot commented Jan 24, 2026

Hi @sparkecho, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.


@sparkecho sparkecho force-pushed the feature/rmsnorm-fp4quant branch from ee54faf to ff0b38b on January 24, 2026 11:05
@mergify mergify bot added the nvidia label Jan 25, 2026

mergify bot commented Jan 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sparkecho.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sparkecho

$ python tests/evals/gsm8k/gsm8k_eval.py --port 8000

PR

Results:
Accuracy: 0.874
Invalid responses: 0.000
Total latency: 29.369 s
Questions per second: 44.912
Total output tokens: 154000
Output tokens per second: 5243.655

MAIN

Results:
Accuracy: 0.877
Invalid responses: 0.000
Total latency: 29.393 s
Questions per second: 44.875
Total output tokens: 154299
Output tokens per second: 5249.573

@sparkecho sparkecho force-pushed the feature/rmsnorm-fp4quant branch from 4e1fd88 to a779957 on January 31, 2026 14:43
@mergify mergify bot removed the needs-rebase label Jan 31, 2026

mergify bot commented Jan 31, 2026

Hi @sparkecho, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.


@sparkecho sparkecho changed the title [WIP][Feat] Add RMSNorm NvFp4 Quant Operator (#32612) [Feat] Add RMSNorm NvFp4 Quant Operator (#32612) Jan 31, 2026
@ProExpertProg

When you run with -rms_norm, you need to manually enable the fusion; can you try that?

Please run the following cases E2E, all on your PR:

  • no fusion, -rms_norm: -cc.custom_ops+=-rms_norm -cc.pass_config.fuse_rms_quant=False
  • no fusion, +rms_norm: -cc.custom_ops+=+rms_norm -cc.pass_config.fuse_rms_quant=False
  • fusion, -rms_norm: -cc.custom_ops+=-rms_norm -cc.pass_config.fuse_rms_quant=True
  • fusion, +rms_norm: -cc.custom_ops+=+rms_norm -cc.pass_config.fuse_rms_quant=True


sparkecho commented Feb 9, 2026

When you run with -rms_norm, you need to manually enable the fusion, can you try that?

Please run the following cases E2E, all on your PR:

* no fusion, -rms_norm: `-cc.custom_ops+=-rms_norm -cc.pass_config.fuse_rms_quant=False`

* no fusion, +rms_norm: `-cc.custom_ops+=+rms_norm -cc.pass_config.fuse_rms_quant=False`

* fusion, -rms_norm: `-cc.custom_ops+=-rms_norm -cc.pass_config.fuse_rms_quant=True`

* fusion,+rms_norm: `-cc.custom_ops+=+rms_norm -cc.pass_config.fuse_rms_quant=True`

Benchmark results indicate that -rms_norm yields better performance than +rms_norm. Also, the performance gain from fusion is trivial when compared to the no-fusion case.

Case 1: no fusion, -rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 1.5 | 7.0  | 10.0  | 10.0 | 111.7 | 540.0 | 454.3  | 919.9 | 460.8 | 1442.3 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

Case 2: no fusion, +rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"], "pass_config":{"fuse_norm_quant":false}}'

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 1.5 | 7.0  | 10.0  | 10.0 | 105.5 | 534.8 | 429.4  | 910.2 | 441.1 | 1427.8 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

Case 3: fusion, -rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 1.5 | 7.0  | 10.0  | 10.0 | 147.1 | 530.6 | 219.8  | 927.1 | 250.3 | 1450.2 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

Case 4: fusion, +rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"], "pass_config":{"fuse_norm_quant":true}}'

|============|=====|======|=======|======|=======|=======|========|=======|=======|========|
| Benchmark  | Requests               |||| Input Tokens || Output Tokens || Total Tokens  ||
| Strategy   | Per Sec   || Concurrency || Per Sec      || Per Sec       || Per Sec       ||
|            | Mdn | Mean | Mdn   | Mean | Mdn   | Mean  | Mdn    | Mean  | Mdn   | Mean   |
|------------|-----|------|-------|------|-------|-------|--------|-------|-------|--------|
| concurrent | 2.9 | 7.0  | 10.0  | 10.0 | 216.5 | 536.8 | 455.3  | 912.5 | 457.8 | 1427.6 |
|============|=====|======|=======|======|=======|=======|========|=======|=======|========|

@ProExpertProg

I don't think the fusion benefit is just trivial; it looks decent to me! Could you collect lm_eval numbers, as well as vllm bench latency for the fusion vs. no-fusion cases (with -rms_norm, which should be the default anyway)?

@sparkecho

I viewed the gains as trivial because the delta for +rms_norm (910.2 → 912.5) is quite small compared to the -rms_norm case (919.9 → 927.1). Regardless, I'm happy to defer to your expertise here.
I have the lm_eval and bench latency runs in progress and will update you shortly. Appreciate the input.

@sparkecho

Case 1: no fusion, -rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'

# latency
vllm bench serve --input-len 100 --output-len 100 --num-prompts 8
============ Serving Benchmark Result ============
Successful requests:                     8
Failed requests:                         0
Benchmark duration (s):                  1.00
Total input tokens:                      800
Total generated tokens:                  800
Request throughput (req/s):              8.01
Output token throughput (tok/s):         800.73
Peak output token throughput (tok/s):    800.00
Peak concurrent requests:                8.00
Total token throughput (tok/s):          1601.46
---------------Time to First Token----------------
Mean TTFT (ms):                          108.47
Median TTFT (ms):                        108.58
P99 TTFT (ms):                           109.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.96
Median TPOT (ms):                        8.98
P99 TPOT (ms):                           8.98
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.96
Median ITL (ms):                         8.99
P99 ITL (ms):                            9.60
==================================================

# lm_eval
vllm ({'pretrained': 'nvidia/Qwen3-30B-A3B-NVFP4', 'quantization': 'modelopt_fp4', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'compilation_config': {'custom_ops': ['none', '-rms_norm'], 'pass_config': {'fuse_norm_quant': False}}}), gen_kwargs: (temperature=0.0), limit: 500, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.888|±  |0.0141|
|     |       |strict-match    |     5|exact_match|↑  |0.892|±  |0.0139|

Case 2: no fusion, +rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"], "pass_config":{"fuse_norm_quant":false}}'

# latency
============ Serving Benchmark Result ============
Successful requests:                     8
Failed requests:                         0
Benchmark duration (s):                  1.07
Total input tokens:                      800
Total generated tokens:                  800
Request throughput (req/s):              7.50
Output token throughput (tok/s):         749.61
Peak output token throughput (tok/s):    737.00
Peak concurrent requests:                8.00
Total token throughput (tok/s):          1499.22
---------------Time to First Token----------------
Mean TTFT (ms):                          152.32
Median TTFT (ms):                        152.39
P99 TTFT (ms):                           153.03
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.22
Median TPOT (ms):                        9.23
P99 TPOT (ms):                           9.23
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.22
Median ITL (ms):                         9.25
P99 ITL (ms):                            9.75
==================================================

# lm_eval
vllm ({'pretrained': 'nvidia/Qwen3-30B-A3B-NVFP4', 'quantization': 'modelopt_fp4', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'compilation_config': {'custom_ops': ['none', '+rms_norm'], 'pass_config': {'fuse_norm_quant': False}}}), gen_kwargs: (temperature=0.0), limit: 500, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.890|±  |0.0140|
|     |       |strict-match    |     5|exact_match|↑  |0.878|±  |0.0147|

Case 3: fusion, -rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'

# latency
============ Serving Benchmark Result ============
Successful requests:                     8
Failed requests:                         0
Benchmark duration (s):                  0.99
Total input tokens:                      800
Total generated tokens:                  800
Request throughput (req/s):              8.09
Output token throughput (tok/s):         808.82
Peak output token throughput (tok/s):    800.00
Peak concurrent requests:                8.00
Total token throughput (tok/s):          1617.64
---------------Time to First Token----------------
Mean TTFT (ms):                          83.16
Median TTFT (ms):                        82.52
P99 TTFT (ms):                           84.81
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.10
Median TPOT (ms):                        9.09
P99 TPOT (ms):                           9.12
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.10
Median ITL (ms):                         9.08
P99 ITL (ms):                            10.12
==================================================

# lm_eval
vllm ({'pretrained': 'nvidia/Qwen3-30B-A3B-NVFP4', 'quantization': 'modelopt_fp4', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'compilation_config': {'custom_ops': ['none', '-rms_norm'], 'pass_config': {'fuse_norm_quant': True}}}), gen_kwargs: (temperature=0.0), limit: 500, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.88|±  |0.0145|
|     |       |strict-match    |     5|exact_match|↑  | 0.88|±  |0.0145|

Case 4: fusion, +rms_norm

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["+rms_norm"], "pass_config":{"fuse_norm_quant":true}}'

# latency
============ Serving Benchmark Result ============
Successful requests:                     8
Failed requests:                         0
Benchmark duration (s):                  0.95
Total input tokens:                      800
Total generated tokens:                  800
Request throughput (req/s):              8.46
Output token throughput (tok/s):         846.24
Peak output token throughput (tok/s):    800.00
Peak concurrent requests:                8.00
Total token throughput (tok/s):          1692.48
---------------Time to First Token----------------
Mean TTFT (ms):                          36.16
Median TTFT (ms):                        36.88
P99 TTFT (ms):                           37.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.16
Median TPOT (ms):                        9.16
P99 TPOT (ms):                           9.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.16
Median ITL (ms):                         9.22
P99 ITL (ms):                            9.85
==================================================

# lm_eval
vllm ({'pretrained': 'nvidia/Qwen3-30B-A3B-NVFP4', 'quantization': 'modelopt_fp4', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'compilation_config': {'custom_ops': ['none', '+rms_norm'], 'pass_config': {'fuse_norm_quant': True}}}), gen_kwargs: (temperature=0.0), limit: 500, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.878|±  |0.0147|
|     |       |strict-match    |     5|exact_match|↑  |0.888|±  |0.0141|

@ProExpertProg

Not trivial: -rms_norm is better and is the default. Not sure why you focused on +rms_norm initially, so let's focus on -rms_norm moving forward.

@ProExpertProg

The serving benchmarks seem off; can you rerun with a larger number of requests? And please run vllm bench latency as well for 1 output token and {512, 2048, 8192} input tokens (and batch size 1).

@sparkecho

The serving benchmarks seem off - can you rerun with a larger # of requests? And please run vllm bench latency as well for 1 output token and {512, 2048, 8192} input tokens (and batch size 1)

Got it. I'll rerun the benchmarks with a larger number of requests and collect the latency data for those specific input sizes. Will update you once it's done.

@sparkecho

I’m a bit confused by the latest data I've gathered. Could you take a look at the commands I used to make sure the setup is correct?

| Metric | No Fusion | Fusion | Delta |
|---|---|---|---|
| Throughput (req/s) | 91.51 | 91.24 | -0.3% |
| Output tok/s | 9151.10 | 9123.70 | -0.3% |
| Mean TTFT (ms) | 1737.80 | 1733.35 | -0.3% |
| P99 ITL (ms) | 42.77 | 32.25 | -24.6% |
| Decode latency, 512 in (ms) | 19.96 | 19.73 | -1.1% |
| Decode latency, 2048 in (ms) | 43.47 | 44.40 | +2.1% |
| Decode latency, 8192 in (ms) | 242.32 | 243.33 | +0.4% |
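
For reference, the deltas above are plain relative changes against the no-fusion baseline; a minimal sketch (helper name is illustrative):

```python
def delta_pct(baseline, value):
    # Relative change of `value` vs `baseline`, in percent.
    return (value - baseline) / baseline * 100.0

# e.g. output throughput: 9151.10 tok/s (no fusion) vs 9123.70 tok/s (fusion)
```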

no fusion

# Throughput

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

============ Serving Benchmark Result ============
Successful requests:                     512
Failed requests:                         0
Benchmark duration (s):                  5.59
Total input tokens:                      51200
Total generated tokens:                  51200
Request throughput (req/s):              91.51
Output token throughput (tok/s):         9151.10
Peak output token throughput (tok/s):    10194.00
Peak concurrent requests:                512.00
Total token throughput (tok/s):          18302.19
---------------Time to First Token----------------
Mean TTFT (ms):                          1737.80
Median TTFT (ms):                        1214.82
P99 TTFT (ms):                           3132.17
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.11
Median TPOT (ms):                        25.18
P99 TPOT (ms):                           25.41
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.12
Median ITL (ms):                         25.41
P99 ITL (ms):                            42.77
==================================================

# Latency
for INPUT_LEN in 512 2048 8192; do
  vllm bench latency --input-len $INPUT_LEN --output-len 1 --batch-size 1 \
    --model nvidia/Qwen3-30B-A3B-NVFP4 \
    --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'
done

## Latency(512)
Avg latency: 0.01995649899666508 seconds
10% percentile latency: 0.01870999380480498 seconds
25% percentile latency: 0.020106525626033545 seconds
50% percentile latency: 0.020304277422837913 seconds
75% percentile latency: 0.020339994109235704 seconds
90% percentile latency: 0.02036972565110773 seconds
99% percentile latency: 0.020474462660495192 seconds

## Latency(2048)
Avg latency: 0.043471691710874436 seconds
10% percentile latency: 0.04297889433801174 seconds
25% percentile latency: 0.0432516397559084 seconds
50% percentile latency: 0.043448852957226336 seconds
75% percentile latency: 0.043724314542487264 seconds
90% percentile latency: 0.04381812172941864 seconds
99% percentile latency: 0.044264794166665525 seconds

## Latency(8192)
Avg latency: 0.2423217370407656 seconds
10% percentile latency: 0.24142448962666094 seconds
25% percentile latency: 0.24152780405711383 seconds
50% percentile latency: 0.24175500660203397 seconds
75% percentile latency: 0.2434669691720046 seconds
90% percentile latency: 0.24388617349322886 seconds
99% percentile latency: 0.24417069428600371 seconds

fusion

# Throughput

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

============ Serving Benchmark Result ============
Successful requests:                     512
Failed requests:                         0
Benchmark duration (s):                  5.61
Total input tokens:                      51200
Total generated tokens:                  51200
Request throughput (req/s):              91.24
Output token throughput (tok/s):         9123.70
Peak output token throughput (tok/s):    10239.00
Peak concurrent requests:                512.00
Total token throughput (tok/s):          18247.40
---------------Time to First Token----------------
Mean TTFT (ms):                          1733.35
Median TTFT (ms):                        1271.18
P99 TTFT (ms):                           3122.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.20
Median TPOT (ms):                        25.21
P99 TPOT (ms):                           25.44
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.21
Median ITL (ms):                         25.51
P99 ITL (ms):                            32.25
==================================================

# Latency
for INPUT_LEN in 512 2048 8192; do
  vllm bench latency --input-len $INPUT_LEN --output-len 1 --batch-size 1 \
    --model nvidia/Qwen3-30B-A3B-NVFP4 \
    --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
done

## Latency(512)
Avg latency: 0.01973303483488659 seconds
10% percentile latency: 0.019426920311525465 seconds
25% percentile latency: 0.0195267666131258 seconds
50% percentile latency: 0.019802328431978822 seconds
75% percentile latency: 0.01986185263376683 seconds
90% percentile latency: 0.01993779807817191 seconds
99% percentile latency: 0.020094353535678237 seconds

## Latency(2048)
Avg latency: 0.0444025077080975 seconds
10% percentile latency: 0.04383873753249645 seconds
25% percentile latency: 0.044155212701298296 seconds
50% percentile latency: 0.04434764594770968 seconds
75% percentile latency: 0.04478400811785832 seconds
90% percentile latency: 0.04498820386361331 seconds
99% percentile latency: 0.045022879024036226 seconds

## Latency(8192)
Avg latency: 0.24332722559726486 seconds
10% percentile latency: 0.24282235836144536 seconds
25% percentile latency: 0.243139213998802 seconds
50% percentile latency: 0.24333323596511036 seconds
75% percentile latency: 0.24347014614613727 seconds
90% percentile latency: 0.2437581259990111 seconds
99% percentile latency: 0.2441371555789374 seconds

@ProExpertProg
Collaborator

Yeah not sure why this is happening. Have you been able to look at a profile to see what's happening there?

@sparkecho
Author

Currently, I only have access to nsys for performance profiling, as ncu permissions are restricted in my current environment. Meanwhile, the servers where I do have ncu access don't yet support FP4. I will proceed with the analysis using nsys for now, while simultaneously looking for an environment that supports both ncu and FP4.

@sparkecho
Author

I’ve identified the likely cause based on the nsys traces:

  • Unfused mode: the expected RMSNorm + FP4 fusion does not trigger; a Triton-generated RMSNorm fusion kernel runs instead.
  • Fused mode: the RMSNorm + FP4 fusion is active, but it introduces an additional triton_poi_fused_to__copy_to_1 operator (likely a dtype conversion), which offsets the performance gains.

Please refer to the two profile screenshots below for details.
*(two nsys profile screenshots attached)*

I am currently working on a fix to eliminate this redundant operator.

@baonudesifeizhai
Contributor

baonudesifeizhai commented Feb 17, 2026

The latest stable release of FlashInfer is v0.6.3, and this version already includes rmsnorm_fp4quant and add_rmsnorm_fp4quant. ...
In the current branch, the in-house implementation is located in:

rmsnorm_nvfp4_quant_kernels.cu

nvfp4_quant_entry.cu (line 121)

torch_bindings.cpp (line 227)

In the same FlashInfer version, the corresponding fused APIs are already available in:

__init__.py (line 103)

rmsnorm_fp4quant.py (line 761)

add_rmsnorm_fp4quant.py (line 1015)

@sparkecho
Author

@baonudesifeizhai Thanks for pointing that out. We'll proceed with our current plan for the time being, as we want to run some benchmarks against FlashInfer before finalizing the roadmap.

// SF layout pads rows to 128, so we need to process those padded rows too
int effective_rows = (num_tokens + 127) / 128 * 128;
dim3 grid(
    std::min(effective_rows, multi_processor_count * num_blocks_per_sm));
Contributor

I can see you didn't reuse the grid layout from the optimized FP4 quant kernels. See

dim3 grid(grid_x, grid_y);

I would give it a try. It will significantly improve the load balancing and the occupancy of your kernel.
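For intuition, here is a rough Python model of the 2D grid layout used by the optimized FP4 quant kernels. This is a sketch under assumed tile sizes: `BLOCK_SIZE` and `VEC_SIZE` are illustrative constants, not the kernel's actual parameters. The idea is that `grid.x` tiles rows (tokens) and `grid.y` tiles column chunks, so each thread owns exactly one 16-element FP4 scale group and no inner column loop is needed.

```python
# Rough model of the 2D launch grid (illustrative only).
# BLOCK_SIZE and VEC_SIZE are assumptions, not the kernel's real values.

BLOCK_SIZE = 256  # assumed threads per block
VEC_SIZE = 16     # elements per FP4 scale group

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def grid_shape(num_tokens: int, hidden_size: int):
    """2D grid: one row per blockIdx.x, column chunks along blockIdx.y."""
    vecs_per_row = ceil_div(hidden_size, VEC_SIZE)
    return num_tokens, ceil_div(vecs_per_row, BLOCK_SIZE)

def thread_to_vec(block_x: int, block_y: int, thread_x: int):
    """Mirrors `colIdx = blockDim.x * blockIdx.y + threadIdx.x`."""
    return block_x, BLOCK_SIZE * block_y + thread_x
```

Under these assumptions, a hidden size of 7168 yields a `(num_tokens, 2)` grid, in contrast to the current 1D launch that pads rows to 128 and loops over column vectors inside each block.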

Author

Thanks for pointing this out! I did notice that implementation, but I wasn't quite sure how to map the 2D grid layout to the RMSNorm calculation. I appreciate the guidance—I'll take another look and give it a try.

// First pass: compute x = input + residual, update residual, compute
// variance
float variance = 0.0f;
for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded;
Contributor

Note that changing the grid layout will also eliminate the need for this inner loop. See

int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;

@ProExpertProg
Collaborator

> I am currently working on a fix to eliminate this redundant operator.

You probably just need to add it to the pass in utility/fix_functionalization.py
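As a toy illustration of what such a pass does (this is a self-contained sketch, not vLLM's actual `FixFunctionalizationPass` API): torch.compile's functionalization turns in-place custom ops into out-of-place nodes followed by copy-back nodes, and the fix pass rewrites registered ops back to their in-place form, dropping the now-redundant copy — plausibly the kind of `triton_poi_fused_to__copy` overhead observed in the traces. All names here are hypothetical.

```python
# Toy model: a "graph" is a list of node dicts. For each functionalized
# node whose target is a known mutating op, re-emit the in-place form
# and drop the copy-back node that functionalization inserted after it.

def fix_functionalization(graph, mutating_ops):
    fixed = []
    skip_copy = False
    for node in graph:
        if node["op"] == "auto_functionalized" and node["target"] in mutating_ops:
            # Rewrite back to the original in-place op.
            fixed.append({"op": "call_inplace",
                          "target": node["target"],
                          "args": node["args"]})
            skip_copy = True
        elif node["op"] == "copy_" and skip_copy:
            # Drop the redundant copy-back emitted by functionalization.
            skip_copy = False
        else:
            skip_copy = False
            fixed.append(node)
    return fixed
```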

@sparkecho
Author

> > I am currently working on a fix to eliminate this redundant operator.
>
> You probably just need to add it to the pass in utility/fix_functionalization.py

Thanks for the pointer! I was a bit lost, but this looks like the right direction. I'll give it a try and report back.

@sparkecho
Author

Impressive results. After adding rmsnorm_fp4 and add_rmsnorm_fp4 operators to the FixFunctionalizationPass in utility/fix_functionalization.py, the benchmark data shows a significant performance boost. I'll rebase onto main and run a full benchmark suite on B200.

no fusion vs. fusion (with FixFunctionalizationPass)

| Metric | No Fusion | Fusion | Delta |
|--------|-----------|--------|-------|
| Throughput (req/s) | 86.72 | 90.26 | +4.1% |
| Output tok/s | 8672.10 | 9026.01 | +4.1% |
| Total tok/s | 17344.21 | 18052.01 | +4.1% |
| Mean TTFT (ms) | 1954.37 | 1769.65 | -9.5% |
| Median TTFT (ms) | 1097.47 | 1256.94 | +14.5% |
| P99 TTFT (ms) | 3324.13 | 3132.76 | -5.8% |
| Mean TPOT (ms) | 25.83 | 25.47 | -1.4% |
| Mean ITL (ms) | 25.85 | 25.48 | -1.4% |
| P99 ITL (ms) | 39.08 | 37.96 | -2.9% |
| Decode Latency (512) | 21.50 ms | 20.50 ms | -4.6% |
| Decode Latency (2048) | 59.28 ms | 53.38 ms | -9.9% |
| Decode Latency (8192) | 241.07 ms | 240.62 ms | -0.2% |

no fusion vs. fusion (without FixFunctionalizationPass)

| Metric | No Fusion | Fusion (w/o FixFunc) | Delta |
|--------|-----------|----------------------|-------|
| Throughput (req/s) | 86.72 | 88.39 | +1.9% |
| Output tok/s | 8672.10 | 8838.71 | +1.9% |
| Total tok/s | 17344.21 | 17677.43 | +1.9% |
| Mean TTFT (ms) | 1954.37 | 1851.28 | -5.3% |
| Median TTFT (ms) | 1097.47 | 1177.06 | +7.2% |
| P99 TTFT (ms) | 3324.13 | 3277.02 | -1.4% |
| Mean TPOT (ms) | 25.83 | 25.59 | -0.9% |
| Mean ITL (ms) | 25.85 | 25.60 | -1.0% |
| P99 ITL (ms) | 39.08 | 38.45 | -1.6% |
| Decode Latency (512) | 21.50 ms | 20.57 ms | -4.3% |
| Decode Latency (2048) | 59.28 ms | 52.77 ms | -11.0% |
| Decode Latency (8192) | 241.07 ms | 239.56 ms | -0.6% |

no fusion

# Throughput

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

============ Serving Benchmark Result ============
Successful requests:                     512
Failed requests:                         0
Benchmark duration (s):                  5.90
Total input tokens:                      51200
Total generated tokens:                  51200
Request throughput (req/s):              86.72
Output token throughput (tok/s):         8672.10
Peak output token throughput (tok/s):    9983.00
Peak concurrent requests:                512.00
Total token throughput (tok/s):          17344.21
---------------Time to First Token----------------
Mean TTFT (ms):                          1954.37
Median TTFT (ms):                        1097.47
P99 TTFT (ms):                           3324.13
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.83
Median TPOT (ms):                        25.86
P99 TPOT (ms):                           26.12
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.85
Median ITL (ms):                         25.89
P99 ITL (ms):                            39.08
==================================================

# Latency
for INPUT_LEN in 512 2048 8192; do
  vllm bench latency --input-len $INPUT_LEN --output-len 1 --batch-size 1 \
    --model nvidia/Qwen3-30B-A3B-NVFP4 \
    --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":false}}'
done

## Latency(512)
Avg latency: 0.021495969500392675 seconds
10% percentile latency: 0.021332726965192707 seconds
25% percentile latency: 0.021444862009957433 seconds
50% percentile latency: 0.021524887008126825 seconds
75% percentile latency: 0.02157235928461887 seconds
90% percentile latency: 0.021606736024841665 seconds
99% percentile latency: 0.021739364502718673 seconds

## Latency(2048)
Avg latency: 0.05927702952952434 seconds
10% percentile latency: 0.0530709809041582 seconds
25% percentile latency: 0.05321109073702246 seconds
50% percentile latency: 0.06110433943103999 seconds
75% percentile latency: 0.06457787402905524 seconds
90% percentile latency: 0.06492912186076864 seconds
99% percentile latency: 0.06639713333337569 seconds

## Latency(8192)
Avg latency: 0.24107048839796336 seconds
10% percentile latency: 0.2404127199668437 seconds
25% percentile latency: 0.2408378387335688 seconds
50% percentile latency: 0.2411143515491858 seconds
75% percentile latency: 0.24146135119372047 seconds
90% percentile latency: 0.24170087215024977 seconds
99% percentile latency: 0.2420339530345518 seconds

fusion (with FixFunctionalizationPass)

# Throughput

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

============ Serving Benchmark Result ============
Successful requests:                     512
Failed requests:                         0
Benchmark duration (s):                  5.67
Total input tokens:                      51200
Total generated tokens:                  51200
Request throughput (req/s):              90.26
Output token throughput (tok/s):         9026.01
Peak output token throughput (tok/s):    10029.00
Peak concurrent requests:                512.00
Total token throughput (tok/s):          18052.01
---------------Time to First Token----------------
Mean TTFT (ms):                          1769.65
Median TTFT (ms):                        1256.94
P99 TTFT (ms):                           3132.76
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.47
Median TPOT (ms):                        25.51
P99 TPOT (ms):                           25.76
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.48
Median ITL (ms):                         25.61
P99 ITL (ms):                            37.96
==================================================

# Latency
for INPUT_LEN in 512 2048 8192; do
  vllm bench latency --input-len $INPUT_LEN --output-len 1 --batch-size 1 \
    --model nvidia/Qwen3-30B-A3B-NVFP4 \
    --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
done

## Latency(512)
Avg latency: 0.020501264669777204 seconds
10% percentile latency: 0.01933318405644968 seconds
25% percentile latency: 0.019893383519956842 seconds
50% percentile latency: 0.02081199048552662 seconds
75% percentile latency: 0.021187862526858225 seconds
90% percentile latency: 0.021234074048697947 seconds
99% percentile latency: 0.02125388088868931 seconds

## Latency(2048)
Avg latency: 0.053381663265948495 seconds
10% percentile latency: 0.05299026691354811 seconds
25% percentile latency: 0.0531004749936983 seconds
50% percentile latency: 0.053239979955833405 seconds
75% percentile latency: 0.05336779504432343 seconds
90% percentile latency: 0.05346503800246864 seconds
99% percentile latency: 0.056105885701254014 seconds

## Latency(8192)
Avg latency: 0.24062173870624975 seconds
10% percentile latency: 0.23964396909577773 seconds
25% percentile latency: 0.24017414401168935 seconds
50% percentile latency: 0.24075220903614536 seconds
75% percentile latency: 0.24120976153062657 seconds
90% percentile latency: 0.24142591102281585 seconds
99% percentile latency: 0.24174504493945279 seconds

fusion (without FixFunctionalizationPass)

# Throughput

vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

============ Serving Benchmark Result ============
Successful requests:                     512
Failed requests:                         0
Benchmark duration (s):                  5.79
Total input tokens:                      51200
Total generated tokens:                  51200
Request throughput (req/s):              88.39
Output token throughput (tok/s):         8838.71
Peak output token throughput (tok/s):    9983.00
Peak concurrent requests:                512.00
Total token throughput (tok/s):          17677.43
---------------Time to First Token----------------
Mean TTFT (ms):                          1851.28
Median TTFT (ms):                        1177.06
P99 TTFT (ms):                           3277.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.59
Median TPOT (ms):                        25.62
P99 TPOT (ms):                           25.95
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.60
Median ITL (ms):                         25.77
P99 ITL (ms):                            38.45
==================================================

# Latency
for INPUT_LEN in 512 2048 8192; do
  vllm bench latency --input-len $INPUT_LEN --output-len 1 --batch-size 1 \
    --model nvidia/Qwen3-30B-A3B-NVFP4 \
    --compilation-config '{"custom_ops":["-rms_norm"], "pass_config":{"fuse_norm_quant":true}}'
done

## Latency(512)
Avg latency: 0.020568706091338148 seconds
10% percentile latency: 0.020496156939771026 seconds
25% percentile latency: 0.02052181851468049 seconds
50% percentile latency: 0.02056680101668462 seconds
75% percentile latency: 0.02060189424082637 seconds
90% percentile latency: 0.02066828014794737 seconds
99% percentile latency: 0.02073630250059068 seconds

## Latency(2048)
Avg latency: 0.052767693092270446 seconds
10% percentile latency: 0.052323477936442944 seconds
25% percentile latency: 0.05266011622734368 seconds
50% percentile latency: 0.05277614848455414 seconds
75% percentile latency: 0.05292663222644478 seconds
90% percentile latency: 0.05306147980736568 seconds
99% percentile latency: 0.05362981965998188 seconds

## Latency(8192)
Avg latency: 0.23955883210292087 seconds
10% percentile latency: 0.2389347587362863 seconds
25% percentile latency: 0.23916606372222304 seconds
50% percentile latency: 0.23951216298155487 seconds
75% percentile latency: 0.2399046059581451 seconds
90% percentile latency: 0.2401798002421856 seconds
99% percentile latency: 0.2406992010702379 seconds

@ProExpertProg
Collaborator

Great results! Please report speed up numbers again once you apply the optimization suggested by @LopezCastroRoberto

@sparkecho
Author

Sounds good. I'll run benchmarks before and after applying the optimization to show the exact improvement, and update here as soon as I have the results.

@sparkecho
Author

The benchmark results on the B200 still look a bit inconsistent (performance for PR-Fused is significantly worse than for PR-Unfused). I plan to conduct a more thorough evaluation once the grid layout optimizations are finalized. Here is a brief summary of the current findings:

| Abbreviation | Branch | Compilation Config |
|--------------|--------|--------------------|
| PR-Fused | PR | `custom_ops:["-rms_norm"]`, `fuse_norm_quant:true` |
| PR-Unfused | PR | `custom_ops:["-rms_norm"]`, `fuse_norm_quant:false` |
| PR-Default | PR | No compilation-config argument |
| Main-Fused | Main | `custom_ops:["-rms_norm"]`, `fuse_norm_quant:true` |
| Main-Default | Main | No compilation-config argument |

E2E Throughput (guidellm)

| Metric | PR-Fused | PR-Unfused | PR-Default | Main-Fused | Main-Default |
|--------|----------|------------|------------|------------|--------------|
| Output tok/s (Mean) | 824.3 | 830.9 | 836.0 | 834.6 | 830.6 |
| Total tok/s (Mean) | 1459.8 | 1471.4 | 1476.3 | 1474.6 | 1470.8 |

Serving Throughput (vllm bench serve)

| Metric | PR-Fused | PR-Unfused | PR-Default | Main-Fused |
|--------|----------|------------|------------|------------|
| Output tok/s | 13,535 | 14,556 | 14,704 | 14,530 |
| Total tok/s | 26,934 | 28,967 | 29,261 | 28,915 |
| Mean TTFT (ms) | 858.5 | 604.5 | 554.4 | 619.6 |
| Mean TPOT (ms) | 28.54 | 28.57 | 28.55 | 28.53 |
| P99 ITL (ms) | 67.74 | 70.77 | 58.61 | 59.11 |

@sparkecho
Author

@ProExpertProg @LopezCastroRoberto Apologies for the long silence. I’ve been tied up with other commitments and lacked a proper development environment, which put my optimization work on hold. Now that my environment is set up, I’d like to pick this back up. Do you still think it’s worth pursuing this feature?

If we proceed, I have a technical concern: reusing the grid and block configurations from nvfp4_quant_kernels.cu implies using a 2D grid. I've attempted this implementation, but it resulted in degraded performance.
Unlike element-wise kernels such as scaled_fp4_quant_sm1xxa or silu_and_mul_nvfp4_quant_sm1xxa, RMSNorm first requires a reduction (the variance) over every element of a row. A 2D grid splits a single row across multiple blocks, which forces a cross-block reduction and synchronization; this increases memory traffic and breaks the single-block-per-row fast path.
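The row-wide dependency described above can be seen in a minimal Python reference (a sketch of the math, not the CUDA kernel): no output element can be produced, and therefore nothing can be quantized to FP4, until the reduction over the entire row has finished. Splitting one row across several blocks inserts a cross-block reduction between these two passes.

```python
import math

def rms_norm_row(row, weight, eps=1e-6):
    # Pass 1: full-row reduction -- every element of the row contributes.
    variance = sum(x * x for x in row) / len(row)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    # Pass 2: element-wise scaling; only valid once pass 1 is complete.
    return [x * inv_rms * w for x, w in zip(row, weight)]
```

With one block per row, pass 1 is a cheap intra-block reduction; with a row spread over a 2D grid, partial sums must be exchanged through global memory before pass 2 can start.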

@ProExpertProg
Collaborator

Yes, let's pursue this. If you prefer, we can merge the FlashInfer rms-fp4 kernel first and then tune yours until it's better, or merge this kernel first and FlashInfer after. I don't have a preference, as long as (1) we get this fusion merged (with a speedup) and (2) we are using the fastest kernel available.

@sparkecho
Author

sparkecho commented Mar 1, 2026 via email

@ProExpertProg
Collaborator

Yep, that sounds great!

@LopezCastroRoberto
Contributor

LopezCastroRoberto commented Mar 1, 2026

We’ve been doing some refactoring to eliminate duplicated helper code for vectorized instructions. Could you please rebase your branch and use the shared helper file instead? #35105

@ProExpertProg
Collaborator

@sparkecho any update on the flashinfer kernel integration? Feel free to open a new PR if you want

@sparkecho
Author

@ProExpertProg Thanks for checking in! Progress has been a bit slower than expected on my end, but I am currently focused on the FlashInfer kernel integration. I’m making steady headway and expect to open a new PR (or submit a fresh update) within the next two days. I really appreciate your patience and the nudge!

@ProExpertProg
Collaborator

ProExpertProg commented Mar 6, 2026

Sounds good, looking forward to it! This is one of the few remaining obvious fusions for models like DeepSeek.

Labels: ci/build, nvidia, performance