Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
symm_output = torch.empty(
hidden_states.shape[0],
hidden_states.shape[1],
- dtype=torch.bfloat16,
+ dtype=hidden_states.dtype,
device=hidden_states.device,
)

Expand Down Expand Up @@ -449,9 +449,11 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
fp8_quantization_type=int(fp8_quantization_type),
)
# TODO: Once https://github.com/flashinfer-ai/flashinfer/issues/2703 is fixed, pass output to moe kernel and remove this copy.
symm_output.copy_(output)
output = symm_output
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)
assert quant_info.w13_input_scale is not None
assert quant_info.output1_scales_scalar is not None
assert quant_info.output1_scales_gate_scalar is not None
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
self.process_weights_hip_scale_padding(layer)

# Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
- if get_moe_runner_backend().is_flashinfer_trtllm():
+ if (
+ get_moe_runner_backend().is_flashinfer_trtllm()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ ):
from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
align_fp8_moe_weights_for_flashinfer_trtllm,
)
Expand Down Expand Up @@ -1600,7 +1603,8 @@ def apply(
local_num_experts=num_local_experts,
intermediate_size=layer.w2_weight.shape[2],
routing_method_type=int(
- getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3)
+ getattr(layer, "routing_method_type", None)
+ or RoutingMethodType.DeepSeekV3
),
block_quant=self.block_quant,
use_mxfp8=getattr(self.quant_config, "use_mxfp8", False),
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,6 +1978,8 @@ def _should_run_flashinfer_autotune(self) -> bool:

if backend_str not in [
"flashinfer_trtllm",
+ # TODO: Enable for flashinfer_trtllm_routed once https://github.com/flashinfer-ai/flashinfer/issues/2749 is fixed.
+ # "flashinfer_trtllm_routed",
"flashinfer_mxfp4",
# TODO: flashinfer_cutlass will cause some flashinfer compilation errors. To be fixed.
# "flashinfer_cutlass",
Expand Down
Loading