
Conversation

@kaixih (Collaborator) commented Aug 24, 2025

This PR further optimizes silu_and_mul_scaled_fp4_grouped_quant for DeepEP low-latency scenarios. When the mask contains small values, the kernel achieves up to 82× speedup over the previous version. The change also improves performance for non-masked NVFP4 expert quantization (scaled_fp4_experts_quant) and NVFP4 group quantization (scaled_fp4_grouped_quant), delivering about 1.5× gains.

The new design changes how threads are assigned in expert-aware kernels. Previously, all threads started at the first expert and then strided to the next, which required extra computation to determine which expert each thread was handling and frequent access to the offset tensors (hence earlier optimizations such as this one). Now, threads are evenly partitioned across experts, and each thread processes only its assigned expert. This (1) removes the need to recompute expert indices or re-access the offset tensor, and (2) makes early exit with masking straightforward.
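
For illustration, here is a minimal CUDA sketch of the thread-to-expert partitioning described above (the kernel name, signature, and the assumption that mask[e] holds the number of valid tokens for expert e are illustrative, not the actual code in this PR):

  #include <cuda_runtime.h>

  // Sketch only: the real cvt_fp16_to_fp4_masked kernel takes the input/output
  // tensors and scale factors and performs silu_and_mul + FP4 quantization in
  // the loop body.
  __global__ void expert_partitioned_sketch(const int* __restrict__ mask,
                                            int n_experts) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int threads_per_expert = gridDim.x * blockDim.x / n_experts;
    if (threads_per_expert == 0) return;  // guard, see the review note below

    // Each thread is bound to a single expert, so the expert index is computed
    // once and the offset tensors never need to be re-read.
    int expert_idx = tid / threads_per_expert;
    if (expert_idx >= n_experts) return;
    int tid_in_expert = tid % threads_per_expert;

    // Early exit with masking: an expert with few (or zero) valid tokens
    // retires its threads after touching only those tokens.
    int valid_tokens = mask[expert_idx];
    for (int t = tid_in_expert; t < valid_tokens; t += threads_per_expert) {
      // ... load row t of this expert, apply silu_and_mul, quantize to FP4 ...
    }
  }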

cc. @wenscarl @fzyzcjy @kushanam

@gemini-code-assist (Contributor, bot) left a comment

Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the performance of the silu_and_mul_scaled_fp4_grouped_quant operation, particularly for low-latency deep learning applications. The core change involves the introduction of a new CUDA kernel, cvt_fp16_to_fp4_masked, which efficiently handles the conversion from FP16 to FP4 while incorporating the SiLU and multiply operations when a mask is present. The system now intelligently selects between this new masked kernel and the existing unmasked kernel based on the input parameters, ensuring a more optimized execution path for relevant scenarios.
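
A minimal sketch of the conditional dispatch described above (the toy kernel bodies, argument lists, and launch configuration are placeholders; the real quant_impl passes the actual tensors and scale factors):

  #include <cuda_runtime.h>

  // Toy stand-ins for the two kernels named in this PR.
  __global__ void cvt_fp16_to_fp4_toy(int n) {}
  __global__ void cvt_fp16_to_fp4_masked_toy(const int* mask, int n_experts) {}

  // Dispatch pattern: use the masked (fused silu_and_mul) kernel only when a
  // mask tensor is supplied; otherwise keep the existing unmasked path.
  void quant_impl_sketch(const int* mask, int n_experts, int n,
                         cudaStream_t stream) {
    dim3 grid(128), block(256);  // illustrative launch configuration
    if (mask != nullptr) {
      cvt_fp16_to_fp4_masked_toy<<<grid, block, 0, stream>>>(mask, n_experts);
    } else {
      cvt_fp16_to_fp4_toy<<<grid, block, 0, stream>>>(n);
    }
  }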

Highlights

  • New CUDA Kernel for Masked Operations: Introduced cvt_fp16_to_fp4_masked kernel to specifically handle scenarios where a mask is provided, integrating the silu_and_mul operation directly within the conversion process from FP16 to FP4.
  • Conditional Kernel Dispatch: The quant_impl function now intelligently dispatches to either the existing cvt_fp16_to_fp4 kernel or the new cvt_fp16_to_fp4_masked kernel based on whether a mask is present, ensuring optimized execution paths.
  • Performance Optimization for Low Latency: The changes are targeted at improving the performance of silu_and_mul_scaled_fp4_grouped_quant for deep learning low-latency use cases, particularly when dealing with masked inputs.

@kaixih kaixih marked this pull request as draft August 24, 2025 08:02

@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request introduces a new CUDA kernel cvt_fp16_to_fp4_masked to optimize the silu_and_mul_scaled_fp4_grouped_quant operation, particularly for low-latency scenarios with masking. The new kernel partitions threads among experts to improve parallelism. While this is a solid optimization strategy, I've found a critical issue in the thread-to-expert mapping logic that could lead to out-of-bounds memory access and division-by-zero errors under certain launch configurations. My review includes a suggested fix to ensure the kernel's correctness and stability.

Comment on lines 518 to 413

Severity: critical

The current logic for assigning threads to experts has two potential critical issues:

  1. Division by zero: If the total number of threads (gridDim.x * blockDim.x) is less than n_experts, threadsPerExpert will be calculated as zero, causing a division-by-zero error in the following lines when calculating tid_in_expert and expert_idx.
  2. Out-of-bounds access: If the total number of threads is not perfectly divisible by n_experts, expert_idx can be calculated to be >= n_experts for some threads. This would lead to out-of-bounds access on input_offset_by_experts and other expert-indexed arrays later in the kernel.

These issues can lead to kernel crashes. I suggest adding guards to handle these cases safely.

  int threadsPerExpert = gridDim.x * blockDim.x / n_experts;
  if (threadsPerExpert == 0) {
    // Not enough threads for at least one per expert.
    return;
  }
  int expert_idx = tid / threadsPerExpert;
  if (expert_idx >= n_experts) {
    return;
  }
  int tid_in_expert = tid % threadsPerExpert;

@kaixih kaixih marked this pull request as ready for review August 24, 2025 20:23
@kaixih (Collaborator, Author) commented Aug 24, 2025

The collected perf can be seen here.

The repro (Python code and its launcher script):

for e in 2 32 128; do
  for m in 1024 2048 4096; do
    for max_m in 8 128 512; do
      python demo_silu_and_mul.py $e $m $m $max_m
    done
  done
done

@fzyzcjy (Collaborator) commented Aug 24, 2025

Feel free to ping me (maybe on Slack if I do not reply here) when this needs a review!

@fzyzcjy (Collaborator) left a comment

Hi, could you please test on these shapes (which are of most interest):

  • 6 local experts
  • 1024/512/256/128 tokens per local expert (1024 is of most interest, corresponding to 768 tokens per rank and 48 ranks), max_m=4096 (or another big number)
  • hidden dim 7168/2048 (7168 is the hidden dim, 2048 is the MoE expert intermediate dim)

@fzyzcjy (Collaborator) commented Aug 25, 2025

I will review once the test case above shows improvement

@kaixih (Collaborator, Author) commented Aug 25, 2025

Just pushed some changes. I noticed that I couldn't replace the existing kernels outright because I see accuracy drops for dsr1 (trying to debug it with @pavanimajety since it affects the cutlass fp4 path). Since that issue is not tightly related to this work, for now I've made the new kernel available only when a mask is used.

As for the kernel improvement on the requested shapes of interest: here.

Basically, we see a ~1.2x to 5x improvement over the previous version.

@kaixih (Collaborator, Author) commented Aug 26, 2025

Added a benchmark script. Below is the output with varying M and K with masks (max_m=4096). This PR focuses on improving cuda_fused_fp4.

Before:

fp4 quant:
          M       K  triton_fp8  cuda_unfused_fp4  cuda_fused_fp4
0    6144.0  2048.0    0.038205          0.090713        0.036352
1    6144.0  4096.0    0.072557          0.175843        0.054437
2    6144.0  7168.0    0.056954          0.321717        0.127193
3   12288.0  2048.0    0.044199          0.174092        0.056606
4   12288.0  4096.0    0.051270          0.344618        0.100222
5   12288.0  7168.0    0.086383          0.634935        0.208719
6   24576.0  2048.0    0.035896          0.340394        0.085674
7   24576.0  4096.0    0.047692          0.680660        0.156566
8   24576.0  7168.0    0.094515          1.262758        0.274896
9   49152.0  2048.0    0.045434          0.673014        0.140357
10  49152.0  4096.0    0.054944          1.352977        0.263792
11  49152.0  7168.0    0.117085          2.514192        0.487246

After:

fp4 quant:
          M       K  triton_fp8  cuda_unfused_fp4  cuda_fused_fp4
0    6144.0  2048.0    0.041347          0.090724        0.023046
1    6144.0  4096.0    0.068494          0.175784        0.049292
2    6144.0  7168.0    0.100620          0.321734        0.129600
3   12288.0  2048.0    0.048770          0.173978        0.035771
4   12288.0  4096.0    0.055030          0.344595        0.041587
5   12288.0  7168.0    0.075559          0.634973        0.120851
6   24576.0  2048.0    0.043742          0.340365        0.024754
7   24576.0  4096.0    0.060470          0.680595        0.060451
8   24576.0  7168.0    0.086134          1.262491        0.123183
9   49152.0  2048.0    0.040820          0.672932        0.020558
10  49152.0  4096.0    0.048584          1.352822        0.065380
11  49152.0  7168.0    0.069905          2.513879        0.123197

@fzyzcjy (Collaborator) left a comment

Made a very simple check and the general direction looks reasonable to me; will review in more detail later.

btw @Alcanderian, do you have interest in reviewing this as well? I am too busy these days :(

@zhyncs (Member) commented Aug 27, 2025

@kaixih wait a bit please, thanks

@zhyncs (Member) commented Aug 27, 2025

I need to bump this first: https://github.com/sgl-project/sglang/actions/runs/17276109894. Please commit after that.

@kaixih (Collaborator, Author) commented Aug 27, 2025

@zhyncs Sure. Thanks for the heads-up.

The last commit enabled masked quant. With it, the leading quant performs almost the same as silu-quant-masked in the e2e benchmark with this PR.
Before the last commit: (benchmark screenshot)
After the last commit: (benchmark screenshot)

That said, the last commit is more of a “nice-to-have.” Once DeepEP supports swizzled NVFP4 output, we’ll be able to invoke GEMM right after the dispatch.

@fzyzcjy (Collaborator) commented Aug 28, 2025

btw, when #9199 (#9199 (comment)) passes accuracy, this kernel will be double-checked e2e.

@kaixih (Collaborator, Author) commented Aug 28, 2025

@fzyzcjy This PR (#9199) already uses my changes for the accuracy tests (@wenscarl patched my changes into his internal repo). Without them, execution time doubles or even triples. I think my PR should be treated as a prerequisite; if possible, we should merge it first so that @wenscarl can run his tests more conveniently.

@kaixih kaixih changed the title [NVIDIA] Optimize the silu_and_mul_scaled_fp4_grouped_quant perf [NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf Aug 29, 2025
@kushanam kushanam merged commit 5c34b4f into sgl-project:main Aug 30, 2025
58 of 63 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
