diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 4924fca1549d..236f836f1404 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1083,8 +1083,6 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) - router_logits = router_logits.to(torch.float32) - with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): @@ -1097,9 +1095,25 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): symm_output = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device ) + + # Some dtype or value requirements for the flashinfer trtllm kernels. + router_logits = ( + router_logits.to(torch.float32) + if self.routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(hidden_states.dtype) + ) + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 + result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, - routing_bias=topk_config.correction_bias.to(hidden_states.dtype), + routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(), gemm1_weights=self.gemm1_weights_fp4_shuffled.data, @@ -1125,9 +1139,9 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): intermediate_size=self.intermediate_size_per_partition, local_expert_offset=self.moe_ep_rank * self.num_local_experts, local_num_experts=self.num_local_experts, - routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, + routed_scaling_factor=routed_scaling_factor, tile_tokens_dim=None, - routing_method_type=RoutingMethodType.DeepSeekV3, + routing_method_type=self.routing_method_type, do_finalize=True, output=symm_output, )[0]