
Conversation


@wenscarl wenscarl commented Aug 14, 2025

@kaixih @kushanam @fzyzcjy

Motivation

Add grouped_gemm_nt_masked from flashinfer to support nvfp4 MoE. This PR exposes the flashinfer_cutedsl_grouped_gemm_nt_masked API.
Depends on PR #9200.
The next step is to integrate it into EpMoE.
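For context, a masked grouped GEMM runs one NT-layout GEMM per expert group, but only over the first masked_m[g] rows of each group's activations (the tokens routed to that expert). A minimal NumPy reference sketch of the masking semantics — names and shapes are illustrative, not the flashinfer signature:

```python
import numpy as np

def grouped_gemm_nt_masked_ref(a, b, masked_m):
    """Per-group masked GEMM in NT layout: a is (G, M, K), b is (G, N, K).

    For each group g, only the first masked_m[g] rows of a[g] participate
    (the tokens routed to expert g); the remaining output rows stay zero.
    """
    num_groups, m, _ = a.shape
    n = b.shape[1]
    out = np.zeros((num_groups, m, n), dtype=np.float32)
    for g in range(num_groups):
        rows = int(masked_m[g])
        out[g, :rows] = a[g, :rows] @ b[g].T  # NT layout: b is transposed
    return out
```

The real kernel additionally quantizes a and b to nvfp4 with block scales; this sketch only pins down what the mask means.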

Modifications

Accuracy Tests

SGLANG_DEEPEP_BF16_DISPATCH=true python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-FP4 \
  --trust-remote-code \
  --disable-radix-cache \
  --max-running-requests 256 \
  --chunked-prefill-size 1024 \
  --mem-fraction-static 0.89 \
  --max-prefill-tokens 16384 \
  --disable-cuda-graph \
  --tp 8 \
  --dp 8 \
  --enable-dp-attention \
  --load-format dummy \
  --enable-ep-moe \
  --quantization modelopt_fp4 \
  --enable-flashinfer-cutedsl-moe \
  --enable-deepep-moe --deepep-mode low_latency 
python3 benchmark/gsm8k/bench_sglang.py   --num-questions 256   --parallel 32   --num-shots 8

Accuracy: 0.980
Invalid: 0.000
Latency: 288.874 s
Output throughput: 93.390 token/s

math-500 dataset, this PR, 64k output length:
nemo-run_1/0 --------------------------------------- math-500 --------------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 500         | 5496       | 2760        | 98.20%           | 0.00%    
nemo-run_1/0 

Benchmarking and Profiling

python3 -m sglang.bench_serving \
  --model nvidia/DeepSeek-R1-0528-FP4 \
  --dataset-name random \
  --backend sglang-oai \
  --random-range-ratio 1 \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --max-concurrency 256 \
  --num-prompts 512 \
  --base-url http://127.0.0.1:30000
This PR

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     512       
Benchmark duration (s):                  93.66     
Total input tokens:                      524288    
Total generated tokens:                  524288    
Total generated tokens (retokenized):    521116    
Request throughput (req/s):              5.47      
Input token throughput (tok/s):          5597.82   
Output token throughput (tok/s):         5597.82   
Total token throughput (tok/s):          11195.64  
Concurrency:                             255.85    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   46802.25  
Median E2E Latency (ms):                 46808.40  
---------------Time to First Token----------------
Mean TTFT (ms):                          9044.15   
Median TTFT (ms):                        10049.01  
P99 TTFT (ms):                           18816.81  
---------------Inter-Token Latency----------------
Mean ITL (ms):                           37.10     
Median ITL (ms):                         28.44     
P95 ITL (ms):                            32.49     
P99 ITL (ms):                            34.47     
Max ITL (ms):                            18177.41  
==================================================


vs.
deepgemm fp8:
============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     512       
Benchmark duration (s):                  106.86    
Total input tokens:                      524288    
Total generated tokens:                  524288    
Total generated tokens (retokenized):    507144    
Request throughput (req/s):              4.79      
Input token throughput (tok/s):          4906.23   
Output token throughput (tok/s):         4906.23   
Total token throughput (tok/s):          9812.46   
Concurrency:                             255.83    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   53394.86  
Median E2E Latency (ms):                 53389.35  
---------------Time to First Token----------------
Mean TTFT (ms):                          10008.84  
Median TTFT (ms):                        10020.20  
P99 TTFT (ms):                           18567.76  
---------------Inter-Token Latency----------------
Mean ITL (ms):                           43.44     
Median ITL (ms):                         34.11     
P95 ITL (ms):                            39.43     
P99 ITL (ms):                            69.28     
Max ITL (ms):                            17688.63  
==================================================

Checklist


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @wenscarl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates support for Flashinfer's CuteDSL masked grouped GEMM into the Mixture-of-Experts (MoE) implementation, aiming to provide a low-latency backend. It introduces a new configuration flag, updates the MoE layer to conditionally use this new path, and ensures compatibility with modelopt_fp4 quantization. Additionally, a new test utility for masked grouped GEMM operations is included.

Highlights

  • New MoE Backend Option: Introduces --enable-flashinfer-cutedsl-moe as a server argument to activate the Flashinfer CuteDSL MoE implementation, targeting low-latency scenarios.
  • MoE Layer Integration: The ep_moe/layer.py is updated to conditionally dispatch to a new forward_flashinfer_masked method when the CuteDSL MoE is enabled and DeepEP low-latency mode is active.
  • Quantization Requirement: Enforces modelopt_fp4 quantization as a prerequisite for utilizing the Flashinfer CuteDSL MoE backend.
  • Masked Grouped GEMM Test Utility: A new file w4a4_bf16_masked.py is added, containing a reference implementation and pytest-based correctness tests for masked grouped GEMM operations, which are fundamental to the CuteDSL functionality.
  • DeepSeekV2 Model Update: The DeepSeekV2 model's initialization is modified to enable the Flashinfer CuteDSL MoE if the corresponding global server argument is set.
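The conditional dispatch in the second highlight can be pictured as a simple selector. A hedged sketch with illustrative names — the actual attributes and fallback path live in ep_moe/layer.py, and only forward_flashinfer_masked is named by this PR:

```python
def select_moe_forward(enable_flashinfer_cutedsl_moe: bool, deepep_mode: str) -> str:
    """Return which MoE forward path would be taken (illustrative sketch)."""
    if enable_flashinfer_cutedsl_moe and deepep_mode == "low_latency":
        return "forward_flashinfer_masked"
    return "forward_default"  # placeholder name for the pre-existing path
```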


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Flashinfer's CuteDSL masked group GEMM for MoE layers. The changes span across server arguments, model implementation, and utility functions. My review has identified some critical issues, including an incomplete placeholder function that will cause runtime errors and the disabling of package version checks, which poses a significant risk. There are also several medium-severity issues related to code duplication, code cleanliness, and TODOs that should be addressed before merging.

Comment on lines 663 to 665

critical

The function forward_flashinfer_masked appears to be a placeholder. The call to flashinfer.cute_dsl.grouped_gemm_nt_masked() is empty, which will cause a TypeError at runtime. This function needs to be fully implemented before this PR can be merged.
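The failure mode described here shows up at call time, not import time: Python only checks required positional arguments when the function is invoked, so an empty placeholder call passes linting and import but raises once the forward path runs. A standalone illustration (the signature is hypothetical, standing in for the real kernel):

```python
def grouped_gemm_nt_masked(a, b, masked_m):
    """Stand-in with required positional arguments, like the real kernel."""
    return a

try:
    grouped_gemm_nt_masked()  # empty call, as in the placeholder
except TypeError as exc:
    print(f"raises at call time: {exc}")
```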

Comment on lines 862 to 874

critical

The implementation of assert_pkg_version has been commented out and replaced with pass. This effectively disables all package version checks, which could lead to hard-to-debug runtime errors if incompatible package versions are used. This change seems risky for a production environment. Was this intended for temporary debugging? If so, it should be reverted before merging.

Suggested change (restore the original implementation in place of the `pass` and the commented-out block):

    try:
        installed_version = version(pkg)
        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
            raise Exception(
                f"{pkg} is installed with version {installed_version}, which "
                f"is less than the minimum required version {min_version}. " + message
            )
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

Comment on lines 471 to 473

medium

The TODO comment and the commented-out assertion suggest that this part of the code is either incomplete or the logic is not fully resolved. It would be great to clarify the purpose of the TODO and either remove or implement the assertion.

Comment on lines 508 to 500

medium

The use_fp8 parameter is now hardcoded to False, and the original use_fp8=True is commented out with a TODO. This suggests the new CuteDSL path may not support FP8. Could you clarify if this is a temporary change for this PR and if there are plans to re-enable FP8 support?


medium

This new file appears to be a test file for masked grouped gemm.

  1. It seems to be in the wrong directory. Test files are usually located in a test/ directory.
  2. The filename w4a4_bf16_masked.py suggests it's for W4A4 quantization, but the content is a general correctness test for grouped GEMM.
  3. There is a large block of commented-out code at the top of the file. This should be removed before merging.

Could you please move this file to an appropriate test directory, rename it to reflect its purpose (e.g., test_grouped_gemm.py), and remove the dead code?

Comment on lines 535 to 543

medium

The validation logic for enable_flashinfer_cutedsl_moe is nearly identical to the logic for enable_flashinfer_cutlass_moe on lines 520-528. To improve maintainability and avoid code duplication, consider refactoring this into a shared check.
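One way to de-duplicate the two nearly identical checks is a parameterized helper. A sketch under the assumption of the fields mentioned in this PR — the real ServerArgs attribute names and the full set of preconditions may differ:

```python
from types import SimpleNamespace

def check_flashinfer_moe_flag(args, flag: str, required_quant: str = "modelopt_fp4") -> None:
    """Shared validation for the cutlass and cutedsl flashinfer MoE flags."""
    if not getattr(args, flag, False):
        return  # flag not set: nothing to validate
    cli_flag = "--" + flag.replace("_", "-")
    if getattr(args, "quantization", None) != required_quant:
        raise ValueError(f"{cli_flag} requires --quantization {required_quant}")

# usage: the same helper serves both flags
args = SimpleNamespace(enable_flashinfer_cutedsl_moe=True, quantization="modelopt_fp4")
check_flashinfer_moe_flag(args, "enable_flashinfer_cutedsl_moe")  # passes silently
```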

@wenscarl force-pushed the flashinfer_cutedsl_grp_gemm branch from 4319c6c to 32bf3e4 on August 17, 2025 05:23
@wenscarl marked this pull request as ready for review on August 17, 2025 05:24
@wenscarl changed the title from "Support Flashinfer CuteDSL masked group gemm" to "[NVIDIA] [2/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked" on Aug 17, 2025

@kaixih kaixih left a comment


Thanks for the prompt work! I’ve left some comments mainly about the behavior and scope of the MoE function.


@fzyzcjy fzyzcjy left a comment


LGTM since this is again only temporary work and will be refined and fused later (and thus there is no need for very high code detail quality). Only some optional nits

@wenscarl wenscarl requested review from fzyzcjy and kaixih August 17, 2025 20:28

fzyzcjy commented Aug 17, 2025

LGTM from directly looking at the code (i.e., accuracy not verified); I think maybe we can move to the next part (the e2e integration) to know whether everything here is ok or there is some accuracy bug.

@wenscarl force-pushed the flashinfer_cutedsl_grp_gemm branch 2 times, most recently from 4e314d3 to 3cea8af on August 20, 2025 19:49
@kushanam kushanam mentioned this pull request Aug 21, 2025
@wenscarl force-pushed the flashinfer_cutedsl_grp_gemm branch 2 times, most recently from 362c0f1 to 325031e on August 23, 2025 21:20
@wenscarl force-pushed the flashinfer_cutedsl_grp_gemm branch from 4d8f43e to 4cac99f on September 8, 2025 16:29
@wenscarl wenscarl requested review from fzyzcjy and kaixih September 9, 2025 02:35

fzyzcjy commented Sep 9, 2025

looks great, could you please resolve the conflicts and retest on latest main

@wenscarl wenscarl requested a review from fzyzcjy September 9, 2025 16:18

fzyzcjy commented Sep 9, 2025

could you please run pre-commit to fix the lint

[screenshot]

and could you please run that newly added test locally and paste results

and we need to wait a bit until main ci is green (it was green yesterday)


fzyzcjy commented Sep 10, 2025

main seems better now, could you please handle the lint


fzyzcjy commented Sep 10, 2025

after CI is green + paste your test results using latest code on gpqa-diamond and math-500 I think it can be merged


fzyzcjy commented Sep 10, 2025

hmm ci does not have deepep...
maybe temporarily remove this test and we will add it back later

image

@wenscarl
Collaborator Author

math-500

nemo-run_1/0   import pynvml  # type: ignore[import]
nemo-run_1/0 --------------------------------------- math-500 --------------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 500         | 5469       | 1317        | 98.80%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 

gpqa-diamond:

nemo-run_1/0   import pynvml  # type: ignore[import]
nemo-run_1/0 ----------------------------------------- gpqa ----------------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 198         | 8651       | 1476        | 76.77%           | 1.01%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 -------------------------------- gpqa-Physics (general) -------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 19          | 6351       | 520         | 89.47%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 -------------------------------- gpqa-Organic Chemistry -------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 72          | 11050      | 807         | 59.72%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 -------------------------------- gpqa-Quantum Mechanics -------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 25          | 6494       | 545         | 100.00%          | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ------------------------- gpqa-Electromagnetism and Photonics -------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 6           | 4131       | 147         | 100.00%          | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 -------------------------- gpqa-High-energy particle physics --------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 14          | 11166      | 1476        | 78.57%           | 7.14%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ------------------------------------ gpqa-Genetics ------------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 4           | 5617       | 316         | 25.00%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ---------------------------------- gpqa-Astrophysics ----------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 13          | 12167      | 1476        | 92.31%           | 7.69%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 -------------------------------- gpqa-Molecular Biology -------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 15          | 4739       | 658         | 73.33%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ------------------------------- gpqa-Chemistry (general) ------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 20          | 7106       | 497         | 85.00%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ----------------------------- gpqa-Relativistic Mechanics -----------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 7           | 6573       | 333         | 85.71%           | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ------------------------------- gpqa-Inorganic Chemistry ------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 1           | 8732       | 288         | 100.00%          | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ------------------------------ gpqa-Optics and Acoustics ------------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 1           | 3192       | 98          | 100.00%          | 0.00%    
nemo-run_1/0 
nemo-run_1/0 
nemo-run_1/0 ---------------------------- gpqa-Condensed Matter Physics ----------------------------
nemo-run_1/0 evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
nemo-run_1/0 pass@1          | 1           | 1324       | 42          | 100.00%          | 0.00%    


fzyzcjy commented Sep 10, 2025

hmm, gpqa looks not good... could you please repeat at least 8-16 times and show all values (just the average, i.e. the first row), and ensure it generates up to 32k tokens.

cc @kaixih I think you have modified my script and make it work on b200 or other cases - could you please share the script with @wenscarl .


wenscarl commented Sep 10, 2025

64k gpqa:
pass@1 | 198 | 8530 | 1489 | 81.31% | 0.51%
pass@1 | 198 | 8005 | 656 | 81.82% | 0.00%
pass@1 | 198 | 7883 | 647 | 83.84% | 0.00%
pass@1 | 198 | 8482 | 1497 | 80.30% | 0.51%

32k
pass@1 | 198 | 7934 | 669 | 81.31% | 0.00%
pass@1 | 198 | 8205 | 777 | 81.31% | 0.00%


fzyzcjy commented Sep 10, 2025

great


fzyzcjy commented Sep 10, 2025

cc @zhyncs this looks good to merge

@zhyncs zhyncs merged commit 3df05f4 into sgl-project:main Sep 12, 2025
99 of 112 checks passed