Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def _fake_fp8_block_scale_moe(
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
) -> torch.Tensor:
return torch.empty(
hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
Expand Down Expand Up @@ -58,6 +59,7 @@ def trtllm_fp8_block_scale_moe_wrapper(
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
) -> torch.Tensor:
try:
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
Expand Down Expand Up @@ -93,7 +95,15 @@ def trtllm_fp8_block_scale_moe_wrapper(
from flashinfer.fused_moe import Fp8QuantizationType

kwargs["fp8_quantization_type"] = Fp8QuantizationType(fp8_quantization_type)
if num_fused_shared_experts is not None and num_fused_shared_experts > 0:
kwargs["num_fused_shared_experts"] = num_fused_shared_experts

from sglang.srt.utils import print_warning_once

print_warning_once(
f"[trtllm_fp8_block_scale_moe] num_fused_shared_experts={kwargs.get('num_fused_shared_experts', 'NOT SET')}, "
f"top_k={kwargs['top_k']}, num_experts={kwargs['num_experts']}"
)
return trtllm_fp8_block_scale_moe(**kwargs)


Expand Down
3 changes: 0 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,9 +1187,6 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert (
topk_output.topk_config.renormalize
), "Renormalize is required for flashinfer trtllm moe"
assert (
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer trtllm moe"
assert (
self.moe_runner_config.is_gated
), "Only gated MoEs are supported for flashinfer trtllm moe"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)

_nfse = runner_config.num_fused_shared_experts or 0
output = _trtllm_fp8_block_scale_moe_wrapper(
routing_logits=(
router_logits.to(torch.float32)
Expand All @@ -433,7 +434,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale_inv,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
top_k=topk_config.top_k - _nfse,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=quant_info.intermediate_size,
Expand All @@ -448,6 +449,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
use_shuffled_weight=use_shuffled_weight,
tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
fp8_quantization_type=int(fp8_quantization_type),
num_fused_shared_experts=_nfse if _nfse > 0 else None,
)
symm_output.copy_(output)
output = symm_output
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,9 @@ def apply(
# router logits directly (no separate apply_with_router_logits needed).
# FlashInfer TRT-LLM routed backend consumes SGLang-computed
# top-k ids/weights (packed into int32) instead of router logits.
global_num_experts = int(getattr(layer, "num_experts"))
num_local_experts = int(getattr(layer, "num_local_experts"))
_nfse = int(getattr(layer, "num_fused_shared_experts", 0))
global_num_experts = int(getattr(layer, "num_experts")) - _nfse
num_local_experts = int(getattr(layer, "num_local_experts")) - _nfse
moe_ep_rank = int(getattr(layer, "moe_ep_rank"))

quant_info = FlashInferTrtllmFp8MoeQuantInfo(
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,8 @@ def determine_num_fused_shared_experts(
)
elif get_moe_expert_parallel_world_size() > 1 and (
not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4)
) and not (
_is_cuda and get_moe_runner_backend().is_flashinfer_trtllm()
):
disable_reason = "Only Deepseek V3/R1 on AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization under expert parallelism."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The reason for disabling shared expert fusion is now potentially misleading. With this change, fusion is also enabled for CUDA with the flashinfer_trtllm backend under expert parallelism. The message should be updated to reflect this to avoid confusion for users on other CUDA configurations.

Suggested change
disable_reason = "Only Deepseek V3/R1 on AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization under expert parallelism."
disable_reason = "Shared experts fusion under expert parallelism is only supported on AMD-platform with capability >= gfx942(MI30x) or on CUDA with the flashinfer_trtllm backend."

elif disable_reason is None and (
Expand Down
12 changes: 8 additions & 4 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2630,10 +2630,14 @@ def _handle_moe_kernel_config(self):
"compressed-tensors",
None,
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'modelopt_mixed', 'compressed-tensors', or bfloat16 (None)."
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
)
# FP8 block-scale (fp8, mxfp8) supports fused shared experts via
# num_fused_shared_experts in trtllm_fp8_block_scale_moe; other
# quant types (BF16, FP4, per-tensor FP8) do not.
if self.quantization not in ["fp8", "mxfp8"]:
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The warning message is correct but could be more informative. It states that shared expert fusion is disabled but doesn't explain why. The code comment explains the reason well; incorporating that into the log message would improve user experience.

Suggested change
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
"FlashInfer TRTLLM MoE is enabled, but fused shared experts are only supported for fp8/mxfp8 quantization. --disable-shared-experts-fusion is automatically set."

)

if self.moe_runner_backend == "flashinfer_trtllm_routed":
assert self.quantization in [
Expand Down
Loading