Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,12 +1117,16 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
topk_config = topk_output.topk_config

hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
router_logits = router_logits.to(torch.float32)
routing_method_type = self.routing_method_type
assert (
routing_method_type is not None
), "flashinfer trtllm moe nvfp4 backend has not been adapted for the current moe layer, you can set routing_method_type (See definition of RoutingMethodType please) for the moe layer explicitly for a quick adaptation."

# DeepSeekV3 style routing requires float32 router logits,
# see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6
if routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

correction_bias = (
None
if topk_config.correction_bias is None
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,6 @@ def _handle_model_specific_adjustments(self):
)
self.disable_overlap_schedule = True
if is_sm100_supported():
self.attention_backend = "triton"
quantization_config = getattr(hf_config, "quantization_config", None)
quant_method = (
quantization_config.get("quant_method")
Expand Down
Loading