Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -461,123 +461,114 @@ def apply(
dispatch_output: StandardDispatchOutput,
) -> CombineInput:

from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

output = cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_weight_scale,
w1_alphas=layer.g1_alphas,
a2_gscale=layer.w2_input_scale_quant,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_weight_scale,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)

return StandardCombineInput(hidden_states=output)

def apply_with_router_logits(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
assert self.use_flashinfer_trtllm

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output

from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe
if self.use_flashinfer_trtllm:
from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe

from sglang.srt.layers.moe.utils import RoutingMethodType
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
# Quantize input hidden states using fp4_quantize
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
x,
layer.w13_input_scale_quant,
self.group_size, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)
hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

# Quantize input hidden states using fp4_quantize
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
x,
layer.w13_input_scale_quant,
self.group_size, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)
hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(x.dtype)
)

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(x.dtype)
)
assert layer.routing_method_type is not None

assert layer.routing_method_type is not None
# DeepSeekV3 style routing requires float32 router logits
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

# DeepSeekV3 style routing requires float32 router logits
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
routed_scaling_factor = (
routed_scaling_factor if routed_scaling_factor is not None else 1.0
)

routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
routed_scaling_factor = (
routed_scaling_factor if routed_scaling_factor is not None else 1.0
)
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
num_tokens = hs_fp4.shape[0]
hidden_size = (
hs_fp4.shape[-1] * 2
if hs_fp4.dtype == torch.uint8
else hs_fp4.shape[-1]
)
symm_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
)

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
num_tokens = hs_fp4.shape[0]
hidden_size = (
hs_fp4.shape[-1] * 2
if hs_fp4.dtype == torch.uint8
else hs_fp4.shape[-1]
)
symm_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
)
output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c,
output1_scale_gate_scalar=layer.g1_alphas,
output2_scale_scalar=layer.g2_alphas,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=layer.routing_method_type,
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
else:
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

output = cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_weight_scale,
w1_alphas=layer.g1_alphas,
a2_gscale=layer.w2_input_scale_quant,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_weight_scale,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)

return trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c,
output1_scale_gate_scalar=layer.g1_alphas,
output2_scale_scalar=layer.g2_alphas,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=layer.routing_method_type,
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
return StandardCombineInput(hidden_states=output)


class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
Expand Down
29 changes: 21 additions & 8 deletions test/registered/8-gpu-models/test_mistral_large3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@

# Runs on both H200 and B200 via nightly-8-gpu-common suite
# Note: trtllm_mla backend may have hardware-specific behavior
register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True)
register_cuda_ci(est_time=3000, suite="nightly-8-gpu-common", nightly=True)

MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
MISTRAL_LARGE3_FP8_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
MISTRAL_LARGE3_NVFP4_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"


@unittest.skipIf(not is_blackwell_system(), "Requires B200")
class TestMistralLarge3(unittest.TestCase):
"""Unified test class for Mistral-Large-3 performance and accuracy.

Two variants:
- basic: TP=8 + trtllm_mla backend
Three variants:
- basic: FP8 model + TP=8 + trtllm_mla backend
- eagle: basic + EAGLE speculative decoding with draft model
- nvfp4: NVFP4 model + TP=8 + trtllm_mla backend

Each variant runs BOTH:
- Performance test (using NightlyBenchmarkRunner)
Expand Down Expand Up @@ -56,22 +58,33 @@ def test_mistral_large3_all_variants(self):
"--speculative-num-draft-tokens=4",
"--kv-cache-dtype=auto",
]
# TODO: add this to base args when FP8 TRTLLM moe is supported
nvfp4_args = [
"--moe-runner-backend=flashinfer_trtllm",
]

variants = [
# Variant: "basic" - TP=8 + trtllm_mla backend
# Variant: "basic" - FP8 model + TP=8 + trtllm_mla backend
ModelLaunchSettings(
MISTRAL_LARGE3_MODEL_PATH,
MISTRAL_LARGE3_FP8_MODEL_PATH,
tp_size=8,
extra_args=base_args,
variant="TP8",
),
# Variant: "eagle" - TP=8 + trtllm_mla + EAGLE with draft model
# Variant: "eagle" - FP8 model + TP=8 + trtllm_mla + EAGLE with draft model
ModelLaunchSettings(
MISTRAL_LARGE3_MODEL_PATH,
MISTRAL_LARGE3_FP8_MODEL_PATH,
tp_size=8,
extra_args=base_args + eagle_args,
variant="TP8+MTP",
),
# Variant: "nvfp4" - NVFP4 model + TP=8 + trtllm_mla backend
ModelLaunchSettings(
MISTRAL_LARGE3_NVFP4_MODEL_PATH,
tp_size=8,
extra_args=base_args + nvfp4_args,
variant="NVFP4",
),
]

run_combined_tests(
Expand Down
Loading