Skip to content
10 changes: 10 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,16 @@ def _weight_loader_impl(
if method.__class__.__name__ == "KTEPWrapperMethod":
method = method.gpu_method

# For flashinfer TRT-LLM BF16 path, process_weights_after_loading reshapes
# expert weights into block layout. During weight update, we must restore
# canonical load-time shapes before copying checkpoint tensors.
if isinstance(method, UnquantizedFusedMoEMethod):
method.maybe_restore_flashinfer_trtllm_bf16_weight_shape_for_load(
layer=self,
param=param,
weight_name=weight_name,
)

loaded_weight = (
loaded_weight.t().contiguous()
if (
Expand Down
120 changes: 88 additions & 32 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,58 +664,107 @@ def fused_experts_none_to_flashinfer_trtllm_bf16(
dispatch_output: StandardDispatchOutput,
quant_info: FlashInferTrtllmBf16MoeQuantInfo,
runner_config: MoeRunnerConfig,
use_routed_topk: bool = False,
) -> StandardCombineInput:
# lazy import
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
from sglang.srt.layers.moe.topk import TopKOutputChecker
from sglang.srt.layers.moe.utils import RoutingMethodType

try:
from flashinfer.fused_moe import trtllm_bf16_moe
except ImportError as e:
raise ImportError(
"Can't import trtllm_bf16_moe from flashinfer. "
"Please check flashinfer version to use bf16 with flashinfer_trtllm backend."
) from e
trtllm_bf16_routed_moe = None
trtllm_bf16_moe = None
if use_routed_topk:
try:
from flashinfer.fused_moe import trtllm_bf16_routed_moe
except ImportError as e:
raise ImportError(
"Can't import trtllm_bf16_routed_moe from flashinfer. "
"Please check flashinfer version to use bf16 with flashinfer_trtllm_routed backend."
) from e
else:
try:
from flashinfer.fused_moe import trtllm_bf16_moe
except ImportError as e:
raise ImportError(
"Can't import trtllm_bf16_moe from flashinfer. "
"Please check flashinfer version to use bf16 with flashinfer_trtllm backend."
) from e

assert (
runner_config.activation == "silu"
), "Only silu is supported for flashinfer trtllm moe"
assert (
dispatch_output.topk_output.topk_config.renormalize
), "Renormalize is required for flashinfer trtllm moe"
if not use_routed_topk:
assert (
dispatch_output.topk_output.topk_config.renormalize
), "Renormalize is required for flashinfer trtllm moe"
assert (
runner_config.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer trtllm moe"
assert (
runner_config.is_gated
), "Only gated MoEs are supported for flashinfer trtllm moe"
from sglang.srt.layers.moe.topk import TopKOutputChecker

assert TopKOutputChecker.format_is_bypassed(dispatch_output.topk_output)

hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_config = topk_output.topk_config

with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
if use_routed_topk:
assert (
runner_config.top_k is not None
), "runner_config.top_k is required for flashinfer_trtllm_routed."
assert TopKOutputChecker.format_is_standard(topk_output)
routing_method_type = runner_config.routing_method_type
if routing_method_type is None:
routing_method_type = RoutingMethodType.Default
elif routing_method_type == RoutingMethodType.DeepSeekV3:
routing_method_type = RoutingMethodType.TopK

# Call the fused kernel
final_hidden_states = trtllm_bf16_moe(
routing_logits=topk_output.router_logits,
routing_bias=topk_config.correction_bias,
hidden_states=hidden_states,
gemm1_weights=quant_info.gemm1_weights,
gemm2_weights=quant_info.gemm2_weights,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=runner_config.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=runner_config.num_local_experts,
routing_method_type=runner_config.routing_method_type,
routed_scaling_factor=runner_config.routed_scaling_factor,
tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]),
)
packed_topk_ids = _pack_topk_for_flashinfer_routed(
topk_ids=topk_output.topk_ids,
topk_weights=topk_output.topk_weights,
)
final_hidden_states = trtllm_bf16_routed_moe(
topk_ids=packed_topk_ids,
hidden_states=hidden_states,
gemm1_weights=quant_info.gemm1_weights,
gemm2_weights=quant_info.gemm2_weights,
num_experts=quant_info.global_num_experts,
top_k=runner_config.top_k,
n_group=None,
topk_group=None,
intermediate_size=runner_config.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=runner_config.num_local_experts,
routing_method_type=routing_method_type,
routed_scaling_factor=(
runner_config.routed_scaling_factor
if runner_config.routed_scaling_factor is not None
else 1.0
),
tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]),
)
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)
topk_config = topk_output.topk_config

# Call the fused kernel
final_hidden_states = trtllm_bf16_moe(
routing_logits=topk_output.router_logits,
routing_bias=topk_config.correction_bias,
hidden_states=hidden_states,
gemm1_weights=quant_info.gemm1_weights,
gemm2_weights=quant_info.gemm2_weights,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=runner_config.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=runner_config.num_local_experts,
routing_method_type=runner_config.routing_method_type,
routed_scaling_factor=runner_config.routed_scaling_factor,
tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]),
)

return StandardCombineInput(hidden_states=final_hidden_states)

Expand Down Expand Up @@ -757,6 +806,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed(
runner_config,
use_routed_topk=True,
)
if isinstance(quant_info, FlashInferTrtllmBf16MoeQuantInfo):
return fused_experts_none_to_flashinfer_trtllm_bf16(
dispatch_output,
quant_info,
runner_config,
use_routed_topk=True,
)
raise TypeError(
f"Unexpected quant_info type for flashinfer_trtllm_routed: {type(quant_info)}"
)
48 changes: 47 additions & 1 deletion python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,58 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

return

def maybe_restore_flashinfer_trtllm_bf16_weight_shape_for_load(
    self,
    layer: torch.nn.Module,
    param: torch.nn.Parameter,
    weight_name: str,
) -> None:
    """Undo the flashinfer TRT-LLM BF16 block layout before a hot weight copy.

    The flashinfer TRT-LLM BF16 postprocess reshapes expert weights into
    a block layout.  Checkpoint tensors arriving during a weight update
    are still in the canonical (load-time) layout, so the parameter is
    reshaped back to that layout here before the copy takes place.
    """
    # Only the routed TRT-LLM backend performs the block-layout reshape.
    if not get_moe_runner_backend().is_flashinfer_trtllm_routed():
        return

    # Only the two fused-expert weight tensors are affected.
    canonical_shape = None
    if weight_name.endswith(".experts.w13_weight"):
        # Gated MoEs stack the gate and up projections, doubling the rows.
        rows = layer.intermediate_size_per_partition
        if layer.moe_runner_config.is_gated:
            rows *= 2
        canonical_shape = (layer.num_local_experts, rows, layer.hidden_size)
    elif weight_name.endswith(".experts.w2_weight"):
        canonical_shape = (
            layer.num_local_experts,
            layer.hidden_size,
            layer.intermediate_size_per_partition,
        )

    # Nothing to do for other tensors, or when the shape is already canonical.
    if canonical_shape is None or tuple(param.data.shape) == canonical_shape:
        return

    # A reshape is only valid when the element counts agree exactly.
    wanted_numel = canonical_shape[0] * canonical_shape[1] * canonical_shape[2]
    if param.data.numel() != wanted_numel:
        raise RuntimeError(
            f"Cannot restore flashinfer TRT-LLM BF16 MoE weight shape for {weight_name}: "
            f"current shape={tuple(param.data.shape)}, expected shape={canonical_shape}."
        )

    param.data = param.data.reshape(canonical_shape)

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
if self.use_flashinfer_trtllm_moe:
backend = MoeRunnerBackend.FLASHINFER_TRTLLM
backend = (
MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED
if get_moe_runner_backend().is_flashinfer_trtllm_routed()
else MoeRunnerBackend.FLASHINFER_TRTLLM
)
elif self.use_triton_kernels:
backend = MoeRunnerBackend.TRITON_KERNELS
else:
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,7 +2623,8 @@ def _handle_moe_kernel_config(self):
assert self.quantization in [
"fp8",
"mxfp8",
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8' or 'mxfp8'."
None,
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', or bfloat16 (None)."
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
popen_launch_server,
)

register_cuda_ci(est_time=500, suite="nightly-4-gpu-b200", nightly=True)
register_cuda_ci(est_time=600, suite="nightly-4-gpu-b200", nightly=True)


class FlashinferTrtllmGenMoeBackendFP8Base:
Expand Down Expand Up @@ -187,5 +187,11 @@ class TestFlashinferTrtllmGenMoeBackendMXFP8Routed(
backend = "flashinfer_trtllm_routed"


class TestFlashinferTrtllmGenMoeBackendBF16Routed(
    FlashinferTrtllmGenMoeBackendBF16Base, CustomTestCase
):
    # BF16 variant of the routed-topk TRT-LLM gen MoE backend test; mirrors
    # the MXFP8 routed test class above, differing only in the base class
    # (presumably selecting the bf16 model/config — confirm against the base).
    backend = "flashinfer_trtllm_routed"


if __name__ == "__main__":
unittest.main()
Loading
Loading