51 changes: 17 additions & 34 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
@@ -275,15 +275,15 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
w13_weight.size(0), # num_experts
)

# Set flashinfer parameters
# Set flashinfer parameters in-place
copy_or_rebind_param(layer, "w13_weight", gemm1_weights_fp4_shuffled.contiguous())
copy_or_rebind_param(layer, "w2_weight", gemm2_weights_fp4_shuffled.contiguous())
copy_or_rebind_param(
layer, "gemm1_weights_fp4_shuffled", gemm1_weights_fp4_shuffled
layer, "w13_weight_scale", gemm1_scales_fp4_shuffled.contiguous()
)
copy_or_rebind_param(
layer, "gemm2_weights_fp4_shuffled", gemm2_weights_fp4_shuffled
layer, "w2_weight_scale", gemm2_scales_fp4_shuffled.contiguous()
)
copy_or_rebind_param(layer, "gemm1_scales_fp4_shuffled", gemm1_scales_fp4_shuffled)
copy_or_rebind_param(layer, "gemm2_scales_fp4_shuffled", gemm2_scales_fp4_shuffled)

# Compute additional scaling factor needed for TRT-LLM
w2_input_scale_quant = cast(torch.Tensor, layer.w2_input_scale_quant)
@@ -294,14 +294,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
(w2_input_scale_quant * g1_alphas).to(torch.float32),
)

# Clean up weights that won't be used by TRT-LLM
del (
layer.w2_weight,
layer.w2_weight_scale,
layer.w13_weight,
layer.w13_weight_scale,
)


@dataclass
class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
@@ -560,11 +552,10 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
class FlashInferTrtllmFp4MoeQuantInfo(MoeQuantInfo):
"""Quantization payload consumed by FlashInfer TRT-LLM FP4 MoE kernels."""

# Shuffled FP4 weights (processed by align_fp4_moe_weights_for_flashinfer_trtllm)
gemm1_weights_fp4_shuffled: torch.Tensor
gemm2_weights_fp4_shuffled: torch.Tensor
gemm1_scales_fp4_shuffled: torch.Tensor
gemm2_scales_fp4_shuffled: torch.Tensor
w13_weight: torch.Tensor
w2_weight: torch.Tensor
w13_weight_scale: torch.Tensor
w2_weight_scale: torch.Tensor

# Scaling factors
g1_scale_c: torch.Tensor
@@ -666,18 +657,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
routing_bias=None,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
@@ -716,18 +703,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
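The net effect of the change to flashinfer_trtllm.py above is that the shuffled FP4 tensors are bound back onto the pre-existing w13_weight / w2_weight / w13_weight_scale / w2_weight_scale parameters, instead of being registered under new gemm*_fp4_shuffled names with the originals deleted afterwards. copy_or_rebind_param is a helper defined elsewhere in sglang; a minimal sketch of the behavior this relies on (illustrative only, not the actual implementation: copy in place when the layout already matches, otherwise rebind to a frozen Parameter) could look like:

import torch
from torch.nn import Module, Parameter

def copy_or_rebind_param(layer: Module, name: str, new_tensor: torch.Tensor) -> None:
    # Illustrative sketch only: reuse the existing parameter storage when the
    # processed tensor has the same shape/dtype, otherwise rebind the attribute
    # to a new frozen Parameter holding the processed tensor.
    old = getattr(layer, name, None)
    if (
        isinstance(old, Parameter)
        and old.shape == new_tensor.shape
        and old.dtype == new_tensor.dtype
    ):
        old.data.copy_(new_tensor)
    else:
        setattr(layer, name, Parameter(new_tensor, requires_grad=False))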
@@ -20,6 +20,7 @@
from sglang.srt.layers.quantization.utils import (
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
replace_parameter,
swizzle_blockscale,
)
from sglang.srt.utils import next_power_of_2, set_weight_attrs
@@ -257,30 +258,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)
logger.debug("Finished shuffling weights for TRT-LLM MOE")

layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
replace_parameter(layer, "w13_weight", gemm1_weights_fp4_shuffled)
replace_parameter(layer, "w2_weight", gemm2_weights_fp4_shuffled)
replace_parameter(layer, "w13_weight_scale", gemm1_scales_fp4_shuffled)
replace_parameter(layer, "w2_weight_scale", gemm2_scales_fp4_shuffled)

# Additional parameter needed for TRT-LLM
layer.g1_scale_c = torch.nn.Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)

# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
@@ -370,18 +357,14 @@ def apply_weights(
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c,
output1_scale_gate_scalar=layer.g1_alphas,
13 changes: 6 additions & 7 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -1980,9 +1980,8 @@ def apply(
), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}"
moe_runner_config = self.moe_runner_config

# FlashInfer TRTLLM FP4 path - layer has shuffled weights only when
# backend is flashinfer_trtllm
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
# FlashInfer TRTLLM FP4 path
if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"):
from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
FlashInferTrtllmFp4MoeQuantInfo,
)
@@ -1994,10 +1993,10 @@
)

quant_info = FlashInferTrtllmFp4MoeQuantInfo(
gemm1_weights_fp4_shuffled=layer.gemm1_weights_fp4_shuffled.data,
gemm2_weights_fp4_shuffled=layer.gemm2_weights_fp4_shuffled.data,
gemm1_scales_fp4_shuffled=layer.gemm1_scales_fp4_shuffled.data,
gemm2_scales_fp4_shuffled=layer.gemm2_scales_fp4_shuffled.data,
w13_weight=layer.w13_weight.data,
w2_weight=layer.w2_weight.data,
w13_weight_scale=layer.w13_weight_scale.data,
w2_weight_scale=layer.w2_weight_scale.data,
g1_scale_c=layer.g1_scale_c.data,
g1_alphas=layer.g1_alphas.data,
g2_alphas=layer.g2_alphas.data,
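With the gemm1_weights_fp4_shuffled attribute gone, the dispatch in modelopt_quant.py now keys off the TRT-LLM-only g1_scale_c parameter together with the backend flag. A minimal sketch of the new path selection, assuming enable_flashinfer_trtllm_moe is a plain boolean attribute on the quant method (the helper itself is hypothetical, not part of the actual modelopt code):

def _use_flashinfer_trtllm_fp4_path(quant_method, layer) -> bool:
    # g1_scale_c is only registered on the flashinfer_trtllm FP4 path, so its
    # presence still distinguishes that path after the weight renaming.
    return bool(getattr(quant_method, "enable_flashinfer_trtllm_moe", False)) and hasattr(
        layer, "g1_scale_c"
    )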
3 changes: 2 additions & 1 deletion python/sglang/srt/server_args.py
@@ -2780,8 +2780,9 @@ def _handle_moe_kernel_config(self):
assert self.quantization in [
"fp8",
"mxfp8",
"modelopt_fp4",
None,
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', or bfloat16 (None)."
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)."
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set."
55 changes: 55 additions & 0 deletions test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
@@ -157,6 +157,49 @@ def test_gsm8k(self):
self.assertGreater(metrics["score"], 0.93)


class FlashinferTrtllmGenMoeBackendNVFP4Base:
backend = None

@classmethod
def setUpClass(cls):
cls.model = "nvidia/Qwen3-30B-A3B-NVFP4"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"},
other_args=[
"--moe-runner-backend",
cls.backend,
"--tp-size",
"4",
"--ep-size",
"4",
"--mem-fraction-static",
"0.7",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="gsm8k",
api="completion",
max_tokens=512,
num_examples=200,
num_threads=128,
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.89)


class TestFlashinferTrtllmGenMoeBackendFP8(
FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase
):
@@ -175,6 +218,12 @@ class TestFlashinferTrtllmGenMoeBackendBF16(
backend = "flashinfer_trtllm"


class TestFlashinferTrtllmGenMoeBackendNVFP4(
FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase
):
backend = "flashinfer_trtllm"


class TestFlashinferTrtllmGenMoeBackendFP8Routed(
FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase
):
@@ -193,5 +242,11 @@ class TestFlashinferTrtllmGenMoeBackendBF16Routed(
backend = "flashinfer_trtllm_routed"


class TestFlashinferTrtllmGenMoeBackendNVFP4Routed(
FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase
):
backend = "flashinfer_trtllm_routed"


if __name__ == "__main__":
unittest.main()
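For local verification, one way to exercise only the new NVFP4 cases might be the snippet below. This is a hypothetical invocation: it assumes the repository root is on sys.path, that the test directories are importable as packages, and that 4 GPUs are available, since the fixtures launch with --tp-size 4.

import unittest

# Hypothetical local runner for the two NVFP4 test classes added above.
loader = unittest.TestLoader()
suite = loader.loadTestsFromNames(
    [
        "test.registered.backends.test_flashinfer_trtllm_gen_moe_backend"
        ".TestFlashinferTrtllmGenMoeBackendNVFP4",
        "test.registered.backends.test_flashinfer_trtllm_gen_moe_backend"
        ".TestFlashinferTrtllmGenMoeBackendNVFP4Routed",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)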