diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 7914265dd3b1..d90e01773474 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -369,7 +369,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( symm_output = torch.empty( hidden_states.shape[0], hidden_states.shape[1], - dtype=torch.bfloat16, + dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -449,9 +449,11 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( tune_max_num_tokens=next_power_of_2(a_q.shape[0]), fp8_quantization_type=int(fp8_quantization_type), ) + # TODO: Once https://github.com/flashinfer-ai/flashinfer/issues/2703 is fixed, pass output to moe kernel and remove this copy. symm_output.copy_(output) output = symm_output else: + assert TopKOutputChecker.format_is_bypassed(topk_output) assert quant_info.w13_input_scale is not None assert quant_info.output1_scales_scalar is not None assert quant_info.output1_scales_gate_scalar is not None diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 98716292297e..b505957931f5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1352,7 +1352,10 @@ def process_weights_after_loading(self, layer: Module) -> None: self.process_weights_hip_scale_padding(layer) # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled - if get_moe_runner_backend().is_flashinfer_trtllm(): + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( align_fp8_moe_weights_for_flashinfer_trtllm, ) @@ -1600,7 +1603,8 @@ def apply( local_num_experts=num_local_experts, intermediate_size=layer.w2_weight.shape[2], routing_method_type=int( - getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) + getattr(layer, "routing_method_type", None) + or RoutingMethodType.DeepSeekV3 ), block_quant=self.block_quant, use_mxfp8=getattr(self.quant_config, "use_mxfp8", False), diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fc9afafac90b..3194c60ed8d0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1978,6 +1978,8 @@ def _should_run_flashinfer_autotune(self) -> bool: if backend_str not in [ "flashinfer_trtllm", + # TODO: Enable for flashinfer_trtllm_routed once https://github.com/flashinfer-ai/flashinfer/issues/2749 is fixed. + # "flashinfer_trtllm_routed", "flashinfer_mxfp4", # TODO: flashinfer_cutlass will cause some flashinfer compilation errors. To be fixed. # "flashinfer_cutlass",