Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,6 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):

hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

router_logits = router_logits.to(torch.float32)

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
Expand All @@ -1097,9 +1095,25 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
symm_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
)

# Some dtype or value requirements for the flashinfer trtllm kernels.
router_logits = (
router_logits.to(torch.float32)
if self.routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
if routed_scaling_factor is None:
routed_scaling_factor = 1.0

result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
Expand All @@ -1125,9 +1139,9 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
routed_scaling_factor=routed_scaling_factor,
tile_tokens_dim=None,
routing_method_type=RoutingMethodType.DeepSeekV3,
routing_method_type=self.routing_method_type,
do_finalize=True,
output=symm_output,
)[0]
Expand Down
Loading