diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py
index 71651efb6b0f..fdd3c9622ad2 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py
@@ -8,13 +8,21 @@
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
+    FlashInferTrtllmFp8MoeQuantInfo,
+)
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsMoEScheme,
 )
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    per_tensor_dequantize,
+    swap_w13_to_w31,
+)
 from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
 
 if TYPE_CHECKING:
@@ -43,6 +51,7 @@ class CompressedTensorsW8A8Fp8MoE(CompressedTensorsMoEScheme):
     def __init__(self, weight_quant, input_quant):
         self.weight_quant = weight_quant
         self.input_quant = input_quant
+        self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm()
 
         per_tensor = (
            self.weight_quant.strategy == QuantizationStrategy.TENSOR
@@ -305,11 +314,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No
             )
             torch.cuda.empty_cache()
 
+        if (
+            self.weight_quant.strategy == QuantizationStrategy.BLOCK
+            and self.use_flashinfer_trtllm
+        ):
+            layer.w13_weight = torch.nn.Parameter(
+                swap_w13_to_w31(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            layer.w13_weight_scale = torch.nn.Parameter(
+                swap_w13_to_w31(layer.w13_weight_scale.data),
+                requires_grad=False,
+            )
+
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        moe_runner_backend = get_moe_runner_backend()
+        if moe_runner_backend.is_auto():
+            moe_runner_backend = MoeRunnerBackend.TRITON
+        self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
 
     def apply_weights(
         self,
@@ -358,16 +383,31 @@ def apply_weights(
             )
             return StandardCombineInput(hidden_states=output)
         elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
-            quant_info = TritonMoeQuantInfo(
-                w13_weight=layer.w13_weight,
-                w2_weight=layer.w2_weight,
-                use_fp8_w8a8=True,
-                w13_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                a13_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=self.weight_block_size,
-            )
+            if self.use_flashinfer_trtllm:
+                quant_info = FlashInferTrtllmFp8MoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    global_num_experts=layer.num_experts,
+                    local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+                    local_num_experts=layer.num_local_experts,
+                    intermediate_size=layer.w2_weight.shape[2],
+                    routing_method_type=layer.routing_method_type,
+                    block_quant=self.block_quant,
+                    weight_block_k=self.weight_block_size[1],
+                    w13_weight_scale_inv=layer.w13_weight_scale,
+                    w2_weight_scale_inv=layer.w2_weight_scale,
+                )
+            else:
+                quant_info = TritonMoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    use_fp8_w8a8=True,
+                    w13_scale=layer.w13_weight_scale,
+                    w2_scale=layer.w2_weight_scale,
+                    a13_scale=layer.w13_input_scale,
+                    a2_scale=layer.w2_input_scale,
+                    block_shape=self.weight_block_size,
+                )
             return self.runner.run(dispatch_output, quant_info)
         else:
             quant_info = TritonMoeQuantInfo(
diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py
index d98381d36903..a3346d6ae836 100644
--- a/python/sglang/srt/layers/quantization/utils.py
+++ b/python/sglang/srt/layers/quantization/utils.py
@@ -594,6 +594,12 @@ def swizzle_blockscale(scale: torch.Tensor):
     )
 
 
+def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return (
+        x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+    )
+
+
 def reorder_w1w3_to_w3w1(
     weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
 ) -> tuple[torch.Tensor, torch.Tensor]:
diff --git a/test/registered/8-gpu-models/test_mistral_large3.py b/test/registered/8-gpu-models/test_mistral_large3.py
index b7fc0b9d8033..eaf415289150 100644
--- a/test/registered/8-gpu-models/test_mistral_large3.py
+++ b/test/registered/8-gpu-models/test_mistral_large3.py
@@ -46,6 +46,7 @@ def test_mistral_large3_all_variants(self):
         base_args = [
             "--tp=8",
             "--attention-backend=trtllm_mla",
+            "--moe-runner-backend=flashinfer_trtllm",
             "--model-loader-extra-config",
             '{"enable_multithread_load": true}',
             "--chat-template=mistral",
@@ -58,10 +59,6 @@ def test_mistral_large3_all_variants(self):
             "--speculative-num-draft-tokens=4",
             "--kv-cache-dtype=auto",
         ]
-        # TODO: add this to base args when FP8 TRTLLM moe is supported
-        nvfp4_args = [
-            "--moe-runner-backend=flashinfer_trtllm",
-        ]
 
         variants = [
             # Variant: "basic" - FP8 model + TP=8 + trtllm_mla backend
@@ -83,7 +80,7 @@ def test_mistral_large3_all_variants(self):
             ModelLaunchSettings(
                 MISTRAL_LARGE3_NVFP4_MODEL_PATH,
                 tp_size=8,
-                extra_args=base_args + nvfp4_args,
+                extra_args=base_args,
                 variant="NVFP4",
             ),
         ]
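Note on the new `swap_w13_to_w31` helper (not part of the patch above): it views the fused w13 tensor of shape `[..., 2 * intermediate_size, hidden_size]` as two stacked halves and swaps them, which is why the patch applies it to both `w13_weight` and `w13_weight_scale` when the FlashInfer TRT-LLM runner is selected; that backend appears to expect the w3/w1 ordering. Below is a minimal, illustrative sketch of the behavior using toy shapes chosen only for demonstration:

```python
import torch


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # View [..., 2 * I, H] as [..., 2, I, H], swap the two halves along the
    # new axis, then restore the original shape.
    return (
        x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
    )


# Toy example: 1 expert, intermediate size 2, hidden size 3.
w1 = torch.ones(2, 3)                          # "gate" half
w3 = torch.full((2, 3), 2.0)                   # "up" half
w13 = torch.cat([w1, w3], dim=0).unsqueeze(0)  # shape [1, 4, 3], w1 stacked above w3
w31 = swap_w13_to_w31(w13)
assert torch.equal(w31[0, :2], w3) and torch.equal(w31[0, 2:], w1)
```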