-
Notifications
You must be signed in to change notification settings - Fork 3.4k
FlashInfer NVFP4 MoE with EP & 2-stream shared expert #7327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a84e7cd
Add flashinfer cutlass MoE
trevor-m 5f6ac95
Fix accuracy issue and clean up args
trevor-m 6659b9b
Merge branch 'main' into flashinfer-cutlass-moe
Alcanderian 50830e1
[fix] fix tp
Alcanderian ad02472
[opt] dual stream
Alcanderian e1a8dfb
lint
Alcanderian a2d918b
[opt] pdl
Alcanderian 6a8c386
opt
Alcanderian 64459c3
update threshold
Alcanderian dafcc94
Merge branch 'main' into flashinfer-cutlass-moe
zhyncs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,11 +29,17 @@ | |
| requantize_with_max_scale, | ||
| ) | ||
| from sglang.srt.layers.radix_attention import RadixAttention | ||
| from sglang.srt.utils import is_cuda | ||
| from sglang.srt.utils import get_bool_env_var, is_cuda | ||
|
|
||
| if is_cuda(): | ||
| from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant | ||
|
|
||
| try: | ||
| from flashinfer import fp4_quantize as fp4_quantize | ||
| from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe | ||
| except ImportError: | ||
| flashinfer_cutlass_fused_moe = None | ||
|
|
||
| # Initialize logger for the module | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -521,6 +527,7 @@ def __init__(self, quant_config: ModelOptFp4Config): | |
| " quantization. Please use Blackwell and" | ||
| " above." | ||
| ) | ||
| self.enable_flashinfer_moe = False | ||
|
|
||
| def create_weights( | ||
| self, | ||
|
|
@@ -727,11 +734,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |
| layer.cutlass_moe_params = CutlassMoEParams( | ||
| CutlassMoEType.BlockscaledFP4, | ||
| device, | ||
| num_experts=layer.num_experts, | ||
| num_experts=layer.num_experts, # global num experts | ||
| intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n | ||
| hidden_size=layer.w13_weight.shape[2] * 2, | ||
| ) # k | ||
|
|
||
| @property | ||
| def load_up_proj_weight_first(self) -> bool: | ||
| # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 | ||
| return self.enable_flashinfer_moe | ||
|
|
||
| def apply( | ||
| self, | ||
| layer: torch.nn.Module, | ||
|
|
@@ -750,6 +762,10 @@ def apply( | |
| inplace: bool = True, | ||
| no_combine: bool = False, | ||
| routed_scaling_factor: Optional[float] = None, | ||
| ep_rank: Optional[int] = None, | ||
| ep_size: Optional[int] = None, | ||
| tp_rank: Optional[int] = None, | ||
| tp_size: Optional[int] = None, | ||
| ) -> torch.Tensor: | ||
|
|
||
| assert activation == "silu", "Only SiLU activation is supported." | ||
|
|
@@ -771,6 +787,46 @@ def apply( | |
| routed_scaling_factor=routed_scaling_factor, | ||
| ) | ||
|
|
||
| if self.enable_flashinfer_moe: | ||
| assert ( | ||
| not apply_router_weight_on_input | ||
| ), "apply_router_weight_on_input is not supported for Flashinfer" | ||
| a1_gs = torch.min(layer.w13_input_scale_quant) | ||
| a2_gs = torch.min(layer.w2_input_scale_quant) | ||
Alcanderian marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w1_blockscale = layer.w13_blockscale_swizzled | ||
| w2_blockscale = layer.w2_blockscale_swizzled | ||
| g1_alphas = layer.g1_alphas | ||
| g2_alphas = layer.g2_alphas | ||
|
|
||
| quant_scales = [ | ||
| a1_gs, | ||
| w1_blockscale.view(torch.int32), | ||
| g1_alphas, | ||
| a2_gs, | ||
| w2_blockscale.view(torch.int32), | ||
| g2_alphas, | ||
| ] | ||
| # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision | ||
| # and fp4 quantized weights loaded from the checkpoint | ||
| out_dtype = x.dtype | ||
| output = x if inplace else torch.zeros_like(x) | ||
Alcanderian marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x, x_sf = fp4_quantize(x, a1_gs) | ||
| output = flashinfer_cutlass_fused_moe( | ||
| x, | ||
| topk_ids.to(torch.int), | ||
| topk_weights, | ||
| layer.w13_weight.view(torch.long), | ||
| layer.w2_weight.view(torch.long), | ||
| out_dtype, | ||
| quant_scales=quant_scales, | ||
| input_sf=x_sf, | ||
| ep_size=ep_size, | ||
| ep_rank=ep_rank, | ||
| tp_size=tp_size, | ||
| tp_rank=tp_rank, | ||
|
||
| ) | ||
| return output[0] | ||
|
|
||
| from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 | ||
|
|
||
| return cutlass_moe_fp4( | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we will need to handle MTP which is not quantized. Can we just disable flashinfer_moe and ep_moe for MTP module if quant_config is None (context)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/sgl-project/sglang/pull/7376/files#r2160016616