
[Perf][Kernel] Fuse SiLU+Mul into NVFP4 Expert Quantization for CUTLASS MoE#18612

Open
JackChuang wants to merge 4 commits into sgl-project:main from bytedance-iaas:horenc/fuse_moe_silu_fp4downproj

Conversation

@JackChuang
Contributor

Summary

In the CUTLASS FP4 MoE pipeline, the path between GEMM1 and GEMM2 previously required three separate steps: allocate an intermediate buffer → silu_and_mul → scaled_fp4_experts_quant. This PR fuses them into a single CUDA kernel, silu_and_mul_scaled_fp4_experts_quant_packed, eliminating one intermediate buffer allocation and one kernel launch. Inspired by vLLM #31832.

Before → After this PR

# Before (3 steps, 1 extra buffer)
intermediate = torch.empty((m * topk, k // 2), ...)            # alloc
silu_and_mul(c1, intermediate)                                 # kernel 1
int_fp4, scales = scaled_fp4_experts_quant(intermediate, ...)  # kernel 2

# After (1 step, no intermediate buffer)
int_fp4, scales = silu_and_mul_scaled_fp4_experts_quant_packed(c1, ...)  # fused kernel

Key Changes

CUDA kernel (nvfp4_expert_quant.cu):

  • Added use_silu_and_mul flag to cvt_fp16_to_fp4 kernels (both low-latency and offset-based variants). Previously SiLU+mul was implicitly tied to mask != nullptr; now it's an independent toggle.
  • New entry function silu_and_mul_scaled_fp4_experts_quant_packed_sm100a — uses expert offsets (not masks) to correctly handle non-uniform token distribution across experts.
  • Input shape (m, 2*k) — gate+up concatenated from GEMM1 output; the kernel reads both halves, applies SiLU(gate)×up, then FP4-quantizes in one pass.
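The activation half of the fused kernel can be sketched in NumPy (a reference sketch only — the FP4 packing and per-block scaling steps are omitted, and `silu_and_mul_reference` is an illustrative name, not the actual kernel API):

```python
import numpy as np

def silu_and_mul_reference(x: np.ndarray) -> np.ndarray:
    """Reference semantics for the fused activation: the input is
    (m, 2*k) with the gate and up halves concatenated along the last
    dimension; the output is (m, k)."""
    k = x.shape[1] // 2
    gate, up = x[:, :k], x[:, k:]
    silu = gate / (1.0 + np.exp(-gate))  # SiLU(x) = x * sigmoid(x)
    return silu * up

# m=1, 2k=4 -> gate = [1.0, -2.0], up = [0.5, 3.0]
x = np.array([[1.0, -2.0, 0.5, 3.0]])
out = silu_and_mul_reference(x)  # shape (1, 2)
```

The fused kernel applies this elementwise step immediately before quantizing, which is why the (m, k) intermediate never has to be materialized in global memory.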

Op registration (common_extension.cc, sgl_kernel_ops.h, nvfp4_quant_entry.cu):

  • Registered silu_and_mul_scaled_fp4_experts_quant_packed as a new torch op.

Python wrapper (gemm.py):

  • silu_and_mul_scaled_fp4_experts_quant_packed() — handles dimension calculation (k = input.shape[1] // 2), output/scale allocation, kernel dispatch, and reinterprets scale output as float8_e4m3fn for GEMM2.
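As a rough illustration of that dimension bookkeeping (`fused_quant_shapes` is a hypothetical helper; it assumes the usual NVFP4 layout of two FP4 values packed per byte and one e4m3 scale per 16-element block, and ignores any padding or swizzling the real kernel may apply to the scale layout):

```python
def fused_quant_shapes(m_topk: int, two_k: int, block_size: int = 16):
    # Input to the fused op is (m*topk, 2*k): gate+up concatenated.
    assert two_k % 2 == 0, "last dim must be 2*k"
    k = two_k // 2                           # feature dim after SiLU+mul
    out_shape = (m_topk, k // 2)             # two FP4 values per uint8 byte
    scale_shape = (m_topk, k // block_size)  # one e4m3 scale per block
    return out_shape, scale_shape

out_shape, scale_shape = fused_quant_shapes(8, 512)
# out_shape == (8, 128); scale_shape == (8, 16)
```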

MoE integration (cutlass_moe.py):

  • Replaced the 3-step unfused path with single silu_and_mul_scaled_fp4_experts_quant_packed(c1, ...) call.

Experimental Results

Experimental Setup

Hardware: 4× GB200
Model: nvidia/Qwen3-30B-A3B-NVFP4
Command:
python3 -m sglang.launch_server --model-path /data06/models/Qwen3-30B-A3B-NVFP4 --trust-remote-code --port 8010 --mem-fraction-static 0.90 --disable-radix-cache

Accuracy, Throughput & Latency Benchmark

Both latency and throughput improve by ~5%, with no change in accuracy:

$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --port 8010

# Base
Accuracy: 0.910
Invalid: 0.000
Latency: 13.969 s
Output throughput: 1840.902 token/s

# PR
Accuracy: 0.910
Invalid: 0.000
Latency: 13.298 s
Output throughput: 1933.693 token/s

Throughput Benchmark

~1.4% total token throughput gain (37021.74 → 37543.55 tok/s)

python3 -m sglang.bench_serving --backend sglang-oai-chat --base-url http://127.0.0.1:8010 --model /data06/models/Qwen3-30B-A3B-NVFP4 --dataset-name random --seed 5 --random-input-len 3500 --random-output-len 1500 --num-prompts 512

# Baseline
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     512
Benchmark duration (s):                  33.99
Total input tokens:                      873513
Total input text tokens:                 873513
Total generated tokens:                  384805
Total generated tokens (retokenized):    383474
Request throughput (req/s):              15.06
Input token throughput (tok/s):          25700.15
Output token throughput (tok/s):         11321.58
Peak output token throughput (tok/s):    25225.00
Peak concurrent requests:                512
Total token throughput (tok/s):          37021.74
Concurrency:                             361.26
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   23981.91
Median E2E Latency (ms):                 25144.01
P90 E2E Latency (ms):                    32213.33
P99 E2E Latency (ms):                    33801.63
---------------Time to First Token----------------
Mean TTFT (ms):                          5709.78
Median TTFT (ms):                        5503.59
P99 TTFT (ms):                           10230.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          38.04
Median TPOT (ms):                        24.75
P99 TPOT (ms):                           256.45
---------------Inter-Token Latency----------------
Mean ITL (ms):                           24.40
Median ITL (ms):                         18.42
P95 ITL (ms):                            21.63
P99 ITL (ms):                            26.44
Max ITL (ms):                            8364.16
==================================================

# PR
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     512
Benchmark duration (s):                  33.52
Total input tokens:                      873513
Total input text tokens:                 873513
Total generated tokens:                  384805
Total generated tokens (retokenized):    383470
Request throughput (req/s):              15.28
Input token throughput (tok/s):          26062.39
Output token throughput (tok/s):         11481.16
Peak output token throughput (tok/s):    24969.00
Peak concurrent requests:                512
Total token throughput (tok/s):          37543.55
Concurrency:                             362.60
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   23736.42
Median E2E Latency (ms):                 25363.32
P90 E2E Latency (ms):                    31750.16
P99 E2E Latency (ms):                    33325.78
---------------Time to First Token----------------
Mean TTFT (ms):                          5464.31
Median TTFT (ms):                        5269.75
P99 TTFT (ms):                           9846.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          37.91
Median TPOT (ms):                        24.83
P99 TPOT (ms):                           252.14
---------------Inter-Token Latency----------------
Mean ITL (ms):                           24.40
Median ITL (ms):                         18.83
P95 ITL (ms):                            22.13
P99 ITL (ms):                            26.22
Max ITL (ms):                            8306.58
==================================================

Latency Benchmark

~2% latency gain (mean E2E latency 435.32 ms → 426.12 ms)

python3 -m sglang.bench_serving --backend sglang-oai-chat --base-url http://127.0.0.1:8010 --model /data06/models/Qwen3-30B-A3B-NVFP4 --dataset-name random --seed 5 --random-input-len 100 --random-output-len 100 --num-prompts 8


# baseline 
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     8
Benchmark duration (s):                  0.67
Total input tokens:                      432
Total input text tokens:                 432
Total generated tokens:                  376
Total generated tokens (retokenized):    376
Request throughput (req/s):              11.95
Input token throughput (tok/s):          645.33
Output token throughput (tok/s):         561.67
Peak output token throughput (tok/s):    376.00
Peak concurrent requests:                8
Total token throughput (tok/s):          1207.00
Concurrency:                             5.20
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   435.32
Median E2E Latency (ms):                 408.65
P90 E2E Latency (ms):                    659.82
P99 E2E Latency (ms):                    660.35
---------------Time to First Token----------------
Mean TTFT (ms):                          129.90
Median TTFT (ms):                        129.97
P99 TTFT (ms):                           130.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.57
Median TPOT (ms):                        6.63
P99 TPOT (ms):                           6.84
---------------Inter-Token Latency----------------
Mean ITL (ms):                           6.64
Median ITL (ms):                         6.58
P95 ITL (ms):                            7.31
P99 ITL (ms):                            8.71
Max ITL (ms):                            8.99
==================================================

# PR
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     8
Benchmark duration (s):                  0.66
Total input tokens:                      432
Total input text tokens:                 432
Total generated tokens:                  376
Total generated tokens (retokenized):    376
Request throughput (req/s):              12.21
Input token throughput (tok/s):          659.11
Output token throughput (tok/s):         573.67
Peak output token throughput (tok/s):    376.00
Peak concurrent requests:                8
Total token throughput (tok/s):          1232.78
Concurrency:                             5.20
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   426.12
Median E2E Latency (ms):                 400.27
P90 E2E Latency (ms):                    646.16
P99 E2E Latency (ms):                    646.77
---------------Time to First Token----------------
Mean TTFT (ms):                          125.13
Median TTFT (ms):                        125.19
P99 TTFT (ms):                           125.78
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.48
Median TPOT (ms):                        6.52
P99 TPOT (ms):                           6.78
---------------Inter-Token Latency----------------
Mean ITL (ms):                           6.54
Median ITL (ms):                         6.49
P95 ITL (ms):                            7.20
P99 ITL (ms):                            7.53
Max ITL (ms):                            7.54
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @JackChuang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the CUTLASS FP4 Mixture-of-Experts (MoE) pipeline by fusing three previously separate operations—intermediate buffer allocation, SiLU+Mul activation, and scaled FP4 expert quantization—into a single, highly efficient CUDA kernel. This fusion reduces memory overhead and kernel launch costs, leading to notable performance gains, specifically around 5% improvement in latency and throughput for MoE models.

Highlights

  • Performance Optimization: Introduced a new fused CUDA kernel, silu_and_mul_scaled_fp4_experts_quant_packed, which combines SiLU+Mul activation and FP4 quantization into a single operation.
  • Resource Efficiency: Eliminated one intermediate buffer allocation and one extra kernel launch in the CUTLASS FP4 MoE pipeline, reducing memory overhead and computational steps.
  • Kernel Flexibility: Modified existing cvt_fp16_to_fp4 kernels to explicitly support a use_silu_and_mul flag, making the SiLU+Mul operation an independent toggle rather than implicitly tied to mask usage.
  • MoE Integration: Implemented a new entry function that leverages expert offsets to correctly handle non-uniform token distribution across experts, improving the robustness of the quantization process.
  • API Extension: Registered the new fused operation as a PyTorch operator and provided a Python wrapper that manages dimension calculations, output allocation, kernel dispatch, and reinterpretation of scale outputs.
  • Benchmark Results: Achieved approximately 5% improvement in both latency and throughput, and a 1.4% gain in overall throughput, as demonstrated by experimental benchmarks on a Qwen3-30B-A3B-NVFP4 model.


Changelog
  • python/sglang/srt/layers/moe/cutlass_moe.py
    • Imported the new silu_and_mul_scaled_fp4_experts_quant_packed function.
    • Replaced the three-step sequence of intermediate buffer allocation, silu_and_mul, and scaled_fp4_experts_quant with a single call to the new fused function.
  • sgl-kernel/csrc/common_extension.cc
    • Registered the new silu_and_mul_scaled_fp4_experts_quant_packed function as a PyTorch CUDA operation.
  • sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
    • Modified cvt_fp16_to_fp4 functions to include a use_silu_and_mul boolean parameter.
    • Adjusted actualColsPerRow calculation to account for the use_silu_and_mul flag.
    • Extended the condition for applying SiLU+Mul to if (use_mask || use_silu_and_mul).
    • Added a new CUDA kernel silu_and_mul_scaled_fp4_experts_quant_packed_sm100a for fused SiLU+Mul and FP4 quantization.
    • Updated calls to quant_impl to pass the new use_silu_and_mul flag.
  • sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
    • Declared and defined the silu_and_mul_scaled_fp4_experts_quant_packed entry point, dispatching to the _sm100a kernel.
  • sgl-kernel/include/sgl_kernel_ops.h
    • Declared the silu_and_mul_scaled_fp4_experts_quant_packed function.
  • sgl-kernel/python/sgl_kernel/init.py
    • Imported silu_and_mul_scaled_fp4_experts_quant_packed to expose it in the Python package.
  • sgl-kernel/python/sgl_kernel/gemm.py
    • Implemented the Python wrapper function silu_and_mul_scaled_fp4_experts_quant_packed, handling input/output tensor creation, dimension calculations, and reinterpreting output scales as float8_e4m3fn.
Activity
  • The pull request author has provided a detailed summary of the changes, including a clear 'Before -> After' example, key changes across different components (CUDA kernel, op registration, Python wrapper, MoE integration), and comprehensive experimental results demonstrating accuracy, throughput, and latency improvements.
  • The author also included a checklist indicating completion of code formatting, documentation updates, and benchmark results.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a performance optimization by fusing the SiLU+Mul operation into the NVFP4 expert quantization kernel for CUTLASS MoE. This is achieved by adding a new fused CUDA kernel silu_and_mul_scaled_fp4_experts_quant_packed, which eliminates an intermediate buffer and a kernel launch. The changes span across CUDA kernel implementations, PyTorch op registration, and Python wrappers.

My review focuses on code structure and style. I've identified a couple of areas for improvement:

  • Refactoring duplicated code in the new CUDA kernel to improve maintainability.
  • Adhering to Python's import conventions.

Overall, the changes are well-documented and the performance benefits are clearly demonstrated. The implementation appears correct and follows the logic described.

Comment on lines +735 to +741
void silu_and_mul_scaled_fp4_experts_quant_packed_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {

Severity: medium

There is significant code duplication between this new function silu_and_mul_scaled_fp4_experts_quant_packed_sm100a and the existing scaled_fp4_experts_quant_sm100a function (lines 568-650). Both functions perform nearly identical checks for tensor properties, dimensions, and types.

To improve maintainability and reduce redundancy, consider refactoring the common logic into a shared helper function. This helper could accept a bool use_silu_and_mul parameter to handle the minor differences in logic, such as the calculation of k and the flag passed to quant_impl.

For example, you could create a helper function like this:

void scaled_fp4_experts_quant_sm100a_impl(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts,
    bool use_silu_and_mul) {
  // ... all common checks and logic ...
  
  auto k = input.size(1);
  if (use_silu_and_mul) {
    TORCH_CHECK(k % 2 == 0, "input last dim must be even (2*k)");
    k /= 2;
  }
  
  // ... more checks ...

  // Call quant_impl with the use_silu_and_mul flag
  if (input.dtype() == at::ScalarType::Half) {
    quant_impl<half>(..., use_silu_and_mul, ...);
  } else if (input.dtype() == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(..., use_silu_and_mul, ...);
  }
}

Then, silu_and_mul_scaled_fp4_experts_quant_packed_sm100a and scaled_fp4_experts_quant_sm100a would become simple wrappers calling this helper.

Comment on the Python wrapper (gemm.py):

m_numtopk, k_input_doubled = input_tensor.shape
k = k_input_doubled // 2  # actual feature dim after SiLU+mul reduces 2*k -> k

import os

Severity: medium

The import os statement is inside the function. According to the PEP 8 style guide, imports should usually be at the top of the file. Please move this import to the top level to improve code style and consistency.

References
  1. PEP 8 recommends that imports are always put at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)


Labels

blackwell SM100/SM120 quant LLM Quantization sgl-kernel
