# [Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+) #19652
Documentation diff (environment variables table):

```diff
@@ -121,6 +121,8 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Apply per token group quantization kernel with fused silu and mul and masked m | `false` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
+| `SGLANG_FORCE_NVFP4_MARLIN` | Force using NVFP4 Marlin fallback kernels even on Blackwell GPUs with native FP4 support | `false` |
+| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | `` |
```
**Collaborator:**

Merge conflict.

**Contributor (Author):**

Hi @b8zhong, thanks for the heads-up, and sorry for these problems:

- `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` line: removed by #21536; I accidentally kept it. I'll fix it.
- `SGLANG_HICACHE_MAX_PINNED_RATIO`: removed by #21884.
- `silu_and_mul` moved from `sgl_kernel` to `sglang.jit_kernel.activation` (#21766).
- The `and not get_fp4_gemm_runner_backend().is_cutlass()` guard is missing on the flashinfer path.

The remaining 11 files are identical to main. If you approve my plan, I'll rebase on the latest main to resolve all of these cleanly. Really sorry about that.
Trailing context of the same hunk:

```diff
 | `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize q_b_proj from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` |
 | `SGLANG_MOE_NVFP4_DISPATCH` | Use nvfp4 for moe dispatch (on flashinfer_cutlass or flashinfer_cutedsl moe runner backend) | `"false"` |
 | `SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE` | Quantize moe of nextn layer from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` |
```
Quantization-info diff (`MarlinMoeQuantInfo`):

```diff
@@ -69,8 +69,13 @@ class MarlinMoeQuantInfo(MoeQuantInfo):
     w13_qzeros: Optional[torch.Tensor] = None
     w2_qzeros: Optional[torch.Tensor] = None

-    # Optional
+    # FP4 Marlin specific (Optional)
+    w13_global_scale: Optional[torch.Tensor] = None
+    w2_global_scale: Optional[torch.Tensor] = None
+
+    # EP support (Optional)
     expert_map: Optional[torch.Tensor] = None
     global_num_experts: int = -1
```
**Collaborator:**

Why do we need these extra args?

**Contributor (Author):**

This is needed for Expert Parallelism (EP). Under EP, each rank holds only a subset of experts, so the local weight tensor's expert count `E` is smaller than the total number of model experts. But `topk_ids` contains global expert IDs, and `moe_align_block_size` creates buckets indexed by expert ID, so it needs the global count to size the output correctly. Without it, global IDs would exceed the local range and cause incorrect routing or out-of-bounds access. When EP is not used, `-1` falls back to `E` (lines 125-126).
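The sizing issue described above can be shown with a toy histogram. This is a pure-Python sketch; `count_tokens_per_expert` is an illustrative stand-in for the bucketing step of `moe_align_block_size`, not the real kernel:

```python
def count_tokens_per_expert(topk_ids, num_experts):
    # Histogram of how many tokens are routed to each expert ID.
    # Buckets must be sized by the *global* expert count, because the
    # router emits global IDs even when this rank holds few experts.
    counts = [0] * num_experts
    for row in topk_ids:
        for expert_id in row:
            counts[expert_id] += 1  # IndexError if sized by the local E
    return counts

# 8 experts in the model; this EP rank holds only experts 4..7 (local E = 4).
topk_ids = [[4, 7], [5, 4]]  # global IDs from the router
local_E, global_E = 4, 8

print(count_tokens_per_expert(topk_ids, global_E))  # [0, 0, 0, 0, 2, 1, 0, 1]

# Sizing the buckets by the local count goes out of bounds:
try:
    count_tokens_per_expert(topk_ids, local_E)
except IndexError:
    print("global IDs overflow locally sized buckets")
```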
The changes land in the `@register_fused_func("none", "marlin")` entry point, in two hunks:

```diff
@@ -106,6 +111,7 @@ def fused_experts_none_to_marlin(
     gating_output=topk_output.router_logits,
     topk_weights=topk_output.topk_weights,
     topk_ids=topk_output.topk_ids,
+    global_num_experts=quant_info.global_num_experts,
     expert_map=quant_info.expert_map,
     g_idx1=quant_info.w13_g_idx,
     g_idx2=quant_info.w2_g_idx,
@@ -118,6 +124,8 @@ def fused_experts_none_to_marlin(
     is_k_full=quant_info.is_k_full,
     inplace=runner_config.inplace,
     routed_scaling_factor=runner_config.routed_scaling_factor,
+    w1_global_scale=quant_info.w13_global_scale,
+    w2_global_scale=quant_info.w2_global_scale,
 ).to(hidden_states.dtype)

 return StandardCombineInput(
```
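The `-1` default on `global_num_experts` acts as a sentinel. A minimal sketch of the fallback rule described in the thread above (the function name is illustrative, not SGLang's actual helper):

```python
def resolve_global_num_experts(local_num_experts: int,
                               global_num_experts: int = -1) -> int:
    # -1 means "EP not in use": fall back to the local expert count E.
    # Under EP, the caller passes the model-wide expert count instead.
    if global_num_experts == -1:
        return local_num_experts
    return global_num_experts

print(resolve_global_num_experts(8))      # no EP: local count is used
print(resolve_global_num_experts(4, 32))  # EP rank holding 4 of 32 experts
```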
**Comment:**

Does this have some performance advantage on Blackwell, or is it just a normal feature?

**Reply:**

No performance advantage. Blackwell's native FP4 path is the default and is faster. This env var is purely for debugging and testing, e.g. comparing native vs. Marlin accuracy, regression-testing the Marlin path on Blackwell, or as a workaround if native FP4 has issues.
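A hedged sketch of the dispatch rule described here. The function name and the compute-capability check are illustrative assumptions, not the actual SGLang code; Blackwell is taken as compute capability 10.0:

```python
import os

def use_nvfp4_marlin(compute_capability: tuple) -> bool:
    # Marlin fallback applies when the GPU lacks native FP4 support
    # (pre-Blackwell, i.e. below SM100), or when the override env var
    # forces it for debugging/testing.
    forced = os.environ.get("SGLANG_FORCE_NVFP4_MARLIN", "false").lower() == "true"
    has_native_fp4 = compute_capability >= (10, 0)
    return forced or not has_native_fp4

print(use_nvfp4_marlin((9, 0)))    # Hopper: Marlin fallback
print(use_nvfp4_marlin((10, 0)))   # Blackwell: native FP4 by default
os.environ["SGLANG_FORCE_NVFP4_MARLIN"] = "true"
print(use_nvfp4_marlin((10, 0)))   # override wins, even on Blackwell
```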
**Comment:**

For RL you might want NVFP4 weights + BF16 activations (+LoRA).