From 640ade10d6fe988150457fc5323cee3ebf4bcff3 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:12:07 +0000 Subject: [PATCH 01/14] WIP Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../moe/moe_runner/flashinfer_trtllm.py | 101 +++++++++++++++++- python/sglang/srt/layers/quantization/fp8.py | 10 ++ 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 3aa20b17a17a..30b9f78de305 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -181,6 +181,76 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: ) +def align_mxfp8_moe_weights_for_flashinfer_trtllm( + layer: Module, quantize: bool = True +) -> None: + """Prepare MXFP8 MoE weights/scales for FlashInfer TRT-LLM kernels. + + For the trtllm MXFP8 path, weights must be: + 1. Quantized to MXFP8 using flashinfer's mxfp8_quantize (swizzled scales) + 2. Rows reordered for gated activation (w13 only) + 3. Shuffled for transposed MMA output + + Args: + layer: The MoE layer to process. + quantize: If True, quantize BF16 weights to MXFP8. + If False, assume weights are already MXFP8 quantized. + """ + from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + from flashinfer.fp8_quantization import mxfp8_quantize + + w13_weight = cast(torch.Tensor, layer.w13_weight) + w2_weight = cast(torch.Tensor, layer.w2_weight) + num_experts = w13_weight.shape[0] + + if quantize: + # Quantize BF16 weights to MXFP8 using flashinfer (swizzled scales for weights) + w13_q_list, w13_s_list = [], [] + w2_q_list, w2_s_list = [], [] + for i in range(num_experts): + w13_q, w13_s = mxfp8_quantize(w13_weight[i], is_sf_swizzled_layout=True) + w13_q_list.append(w13_q) + w13_s_list.append(w13_s.view(torch.uint8)) + + w2_q, w2_s = mxfp8_quantize(w2_weight[i], is_sf_swizzled_layout=True) + w2_q_list.append(w2_q) + w2_s_list.append(w2_s.view(torch.uint8)) + + w13_weight = torch.stack(w13_q_list) + w13_scales = torch.stack(w13_s_list) + w2_weight = torch.stack(w2_q_list) + w2_scales = torch.stack(w2_s_list) + else: + # Already quantized checkpoint — scales are already in layer + w13_scales = cast(torch.Tensor, layer.w13_weight_scale_inv) + w2_scales = cast(torch.Tensor, layer.w2_weight_scale_inv) + + # Reorder rows for gated activation (w13 only — interleaves gate/up halves) + w13_interleaved = torch.stack( + [reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)] + ).reshape_as(w13_weight) + + # Shuffle weights for transposed MMA output (both w13 and w2) + epilogue_tile_m = 128 + w13_shuffled = torch.stack( + [ + shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + ).view(torch.float8_e4m3fn) + w2_shuffled = torch.stack( + [ + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + ).view(torch.float8_e4m3fn) + + layer.w13_weight = Parameter(w13_shuffled, requires_grad=False) + layer.w2_weight = Parameter(w2_shuffled, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_scales, requires_grad=False) + layer.w2_weight_scale_inv = Parameter(w2_scales, requires_grad=False) + + @dataclass class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): """Quantization payload consumed by FlashInfer TRT-LLM FP8 MoE kernels.""" @@ -210,6 +280,9 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): output2_scales_scalar: torch.Tensor | None = None use_routing_scales_on_input: bool = False + # MXFP8 path + use_mxfp8: bool = False + def fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output: StandardDispatchOutput, @@ -217,6 +290,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( runner_config: MoeRunnerConfig, ) -> StandardCombineInput: from flashinfer.fused_moe import ( + Fp8QuantizationType, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, ) @@ -243,12 +317,28 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( routing_method_type = quant_info.routing_method_type if quant_info.block_quant: - assert quant_info.weight_block_k is not None assert quant_info.w13_weight_scale_inv is not None assert quant_info.w2_weight_scale_inv is not None - a_q, a_sf = per_token_group_quant_fp8(hidden_states, quant_info.weight_block_k) - a_sf_t = a_sf.t().contiguous() + if quant_info.use_mxfp8: + # MXFP8 path: quantize activations with flashinfer's mxfp8_quantize + from flashinfer.fp8_quantization import mxfp8_quantize + + a_q, a_sf = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False + ) + a_sf = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1) + fp8_quant_type = Fp8QuantizationType.MxFp8 + use_shuffled_weight = True + else: + # DeepSeek FP8 block scale path + assert quant_info.weight_block_k is not None + a_q, a_sf = per_token_group_quant_fp8( + hidden_states, quant_info.weight_block_k + ) + a_sf = a_sf.t().contiguous() + fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 + use_shuffled_weight = False with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() @@ -265,7 +355,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( ), routing_bias=correction_bias, hidden_states=a_q, - hidden_states_scale=a_sf_t, + hidden_states_scale=a_sf, gemm1_weights=quant_info.w13_weight, gemm1_weights_scale=quant_info.w13_weight_scale_inv, gemm2_weights=quant_info.w2_weight, @@ -285,8 +375,9 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( else 1.0 ), routing_method_type=routing_method_type, - use_shuffled_weight=False, + use_shuffled_weight=use_shuffled_weight, tune_max_num_tokens=next_power_of_2(a_q.shape[0]), + fp8_quantization_type=fp8_quant_type, ) else: assert quant_info.w13_input_scale is not None diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index cb3ca7e0d856..0d450afdf08a 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -972,6 +972,15 @@ def _process_mxfp8_moe_weights(self, layer: Module, quantize: bool = True) -> No if not (_is_cuda and is_sm100_supported()): raise RuntimeError("MXFP8 MoE quantization requires SM100.") + # For trtllm backend, use flashinfer-native MXFP8 weight preparation + if get_moe_runner_backend().is_flashinfer_trtllm(): + from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( + align_mxfp8_moe_weights_for_flashinfer_trtllm, + ) + + align_mxfp8_moe_weights_for_flashinfer_trtllm(layer, quantize=quantize) + return + def _quantize_and_swizzle_with_cutlass_es_kernel(weight: torch.Tensor): from sgl_kernel import es_sm100_mxfp8_blockscaled_grouped_quant @@ -1499,6 +1508,7 @@ def apply( if not self.block_quant else None ), + use_mxfp8=self.use_mxfp8, ) elif self.runner.runner_backend.is_triton(): quant_info = TritonMoeQuantInfo( From cce9f2f0b992d9ae533b33b7be69685338a85f68 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:15:56 +0000 Subject: [PATCH 02/14] WIP Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../srt/layers/moe/fused_moe_triton/layer.py | 15 +- .../srt/layers/quantization/__init__.py | 2 + .../srt/layers/quantization/modelopt_quant.py | 359 ++++++++++++++++++ 3 files changed, 375 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 185f1bea3ea5..fda6f3b630f7 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -704,6 +704,7 @@ def _weight_loader_impl( # Flashinfer assumes w31 format for w13_weight. Same for the scales. if self.use_flashinfer_trtllm_moe and ( isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) + or isinstance(self.quant_method, ModelOptMxfp8MoEMethod) or isinstance(self.quant_method, Fp8MoEMethod) or isinstance(self.quant_method, UnquantizedFusedMoEMethod) or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod) @@ -768,8 +769,20 @@ def _weight_loader_impl( return if "ModelOpt" in self.quant_method.__class__.__name__: - # Determine per-tensor weight scale patterns based on variant is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) + is_mxfp8_variant = isinstance(self.quant_method, ModelOptMxfp8MoEMethod) + + if is_mxfp8_variant: + # MXFP8: weight_scale is block scale (uint8 UE8M0), not per-tensor + if "weight_scale" in weight_name or "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + return # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor per_tensor_conditions = ( diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 9174c08ff559..c3fde3576432 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -31,6 +31,7 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp4Config, ModelOptFp8Config, + ModelOptMxfp8Config, ) from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config @@ -57,6 +58,7 @@ def override_quantization_method(self, *args, **kwargs): "modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8 "modelopt_fp8": ModelOptFp8Config, "modelopt_fp4": ModelOptFp4Config, + "modelopt_mxfp8": ModelOptMxfp8Config, "w8a8_int8": W8A8Int8Config, "w8a8_fp8": W8A8Fp8Config, "awq": AWQConfig, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index d6b3dd68d084..8b05b990ed0f 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -48,6 +48,7 @@ swizzle_blockscale, ) from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.utils import set_weight_attrs from sglang.srt.utils.common import ( get_bool_env_var, is_cuda, @@ -1850,3 +1851,361 @@ def apply_without_routing_weights( ), ) return out + + +class ModelOptMxfp8Config(ModelOptQuantConfig): + """Config class for ModelOpt MXFP8 quantization. + + Handles checkpoints with: + - weight: float8_e4m3fn + - weight_scale: uint8 (UE8M0 block scales, group_size=32) + - k_scale/v_scale: float32 (KV cache FP8 scales) + """ + + def __init__( + self, + kv_cache_quant_algo: Optional[str] = None, + exclude_modules: Optional[List[str]] = None, + packed_modules_mapping: Optional[Dict[str, List[str]]] = None, + ) -> None: + super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping) + + @classmethod + def override_quantization_method(cls, hf_quant_config, user_quant): + return cls._modelopt_override_quantization_method(hf_quant_config, user_quant) + + @classmethod + def get_name(cls) -> str: + return "modelopt_mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 100 # Blackwell required + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptMxfp8Config": + kv_cache_quant_method = None + exclude_modules = None + + quant_method = config.get("quant_algo") + if quant_method is not None: + kv_cache_scheme = config.get("kv_cache_scheme") + if ( + kv_cache_scheme + and kv_cache_scheme.get("type") == "float" + and kv_cache_scheme.get("num_bits") == 8 + ): + kv_cache_quant_method = "FP8" + exclude_modules = config.get("ignore") + else: + try: + quantization_section = cls.get_from_keys(config, ["quantization"]) + quant_method = quantization_section.get("quant_algo") + kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo") + exclude_modules = quantization_section.get("exclude_modules") + except ValueError: + raise ValueError( + "Cannot find 'quant_algo' in the model's quantization config." + ) + + if quant_method is None or "MXFP8" not in quant_method: + raise ValueError( + "ModelOptMxfp8Config only supports MXFP8 quantization. " + f"Got quant_algo={quant_method}." + ) + + return cls( + kv_cache_quant_algo=kv_cache_quant_method, + exclude_modules=exclude_modules, + packed_modules_mapping=config.get("packed_modules_mapping"), + ) + + def is_layer_excluded(self, prefix: str) -> bool: + if not self.exclude_modules: + return False + return any( + module in prefix + or ( + prefix.startswith("language_model.") + and module in prefix.removeprefix("language_model.") + ) + for module in self.exclude_modules + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + return self._get_quant_method( + layer, + prefix, + Linear=ModelOptMxfp8LinearMethod, + Moe=ModelOptMxfp8MoEMethod, + ) + + +class ModelOptMxfp8LinearMethod(LinearMethodBase): + """Linear method for ModelOpt MXFP8 quantization. + + Loads FP8 weights with uint8 UE8M0 block scales (group_size=32). + Uses the existing MXFP8 blockscaled linear infrastructure. + """ + + BLOCK_K = 32 # MXFP8 group size + + def __init__(self, quant_config: ModelOptMxfp8Config): + super().__init__() + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.parameter import BlockQuantScaleParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight: fp8 + layer.register_parameter( + "weight", + ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + # Block scales: uint8 UE8M0, named "weight_scale" to match ModelOpt checkpoint + scale = BlockQuantScaleParameter( + data=torch.zeros( + output_size_per_partition, + input_size_per_partition // self.BLOCK_K, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale", scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Rename weight_scale → weight_scale_inv for downstream MXFP8 code + layer.weight_scale_inv = Parameter( + layer.weight_scale.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from sglang.srt.layers.quantization.fp8_utils import ( + triton_mxfp8_blockscaled_linear, + ) + + if isinstance(x, tuple): + return triton_mxfp8_blockscaled_linear( + input=x[0], + weight=layer.weight, + weight_scale=layer.weight_scale_inv, + input_scale=x[1], + bias=bias, + ) + return triton_mxfp8_blockscaled_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale_inv, + bias=bias, + ) + + +class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): + """MoE method for ModelOpt MXFP8 quantization. + + Loads FP8 expert weights with uint8 UE8M0 block scales. + Uses the existing MXFP8 MoE infrastructure for inference. + """ + + BLOCK_K = 32 # MXFP8 group size + + def __init__(self, quant_config: ModelOptMxfp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + weight_loader = extra_weight_attrs.get("weight_loader") + num_shards = 2 if layer.moe_runner_config.is_gated else 1 + intermediate_size = num_shards * intermediate_size_per_partition + + # Weights: fp8 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Block scales: uint8 UE8M0, named "weight_scale" to match ModelOpt checkpoint + # Shape: [num_experts, N, K // 32] + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + intermediate_size, + hidden_size // self.BLOCK_K, + dtype=torch.uint8, + ), + requires_grad=False, + ) + w13_weight_scale.format_ue8m0 = True + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // self.BLOCK_K, + dtype=torch.uint8, + ), + requires_grad=False, + ) + w2_weight_scale.format_ue8m0 = True + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Set weight_loader and quant_method on scale params for proper loading + extra_weight_attrs_block = dict(extra_weight_attrs) + extra_weight_attrs_block["quant_method"] = ( + FusedMoeWeightScaleSupported.BLOCK.value + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs_block) + set_weight_attrs(w2_weight_scale, extra_weight_attrs_block) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Rename weight_scale → weight_scale_inv for downstream MXFP8 code + layer.w13_weight_scale_inv = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale_inv = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + # Use existing MXFP8 MoE weight preparation infrastructure + if get_moe_runner_backend().is_flashinfer_trtllm(): + from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( + align_mxfp8_moe_weights_for_flashinfer_trtllm, + ) + + # Weights are already quantized FP8 from checkpoint — no re-quantization + align_mxfp8_moe_weights_for_flashinfer_trtllm(layer, quantize=False) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + if get_moe_runner_backend().is_flashinfer_trtllm(): + self.runner = MoeRunner( + MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config + ) + else: + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ) -> "CombineInput": + from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( + FlashInferTrtllmFp8MoeQuantInfo, + fused_experts_none_to_flashinfer_trtllm_fp8, + ) + from sglang.srt.layers.moe.topk import TopKOutputChecker + from sglang.srt.layers.moe.utils import RoutingMethodType + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + and TopKOutputChecker.format_is_bypassed(topk_output) + ): + topk_config = topk_output.topk_config + + quant_info = FlashInferTrtllmFp8MoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + global_num_experts=layer.num_experts, + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + intermediate_size=layer.w2_weight.shape[2], + routing_method_type=RoutingMethodType.Default, + block_quant=True, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + use_mxfp8=True, + ) + + return fused_experts_none_to_flashinfer_trtllm_fp8( + dispatch_output, quant_info, self.moe_runner_config + ) + + # Fallback: triton path + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + per_channel_quant=False, + w13_scale=layer.w13_weight_scale_inv, + w2_scale=layer.w2_weight_scale_inv, + ) + return self.runner.run(dispatch_output, quant_info) From 403ab9431ece331a23298fc2d5d1caceeee6876e Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sat, 7 Feb 2026 23:46:53 +0000 Subject: [PATCH 03/14] WIP Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- python/sglang/srt/configs/model_config.py | 1 + .../srt/layers/moe/fused_moe_triton/layer.py | 5 +- .../moe/moe_runner/flashinfer_trtllm.py | 44 +++++++------- .../srt/layers/quantization/base_config.py | 12 +++- .../srt/layers/quantization/modelopt_quant.py | 60 +++++++++++++------ python/sglang/srt/server_args.py | 6 +- 6 files changed, 82 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 83d4f527f70c..3394f6bb356c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -837,6 +837,7 @@ def _verify_quantization(self) -> None: compatible_quantization_methods = { "modelopt_fp8": ["modelopt"], "modelopt_fp4": ["modelopt"], + "modelopt_mxfp8": ["modelopt"], "petit_nvfp4": ["modelopt"], "w8a8_int8": ["compressed-tensors", "compressed_tensors"], "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index fda6f3b630f7..dddd2df3c564 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -58,7 +58,10 @@ CompressedTensorsMxInt4MoEMethod, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod -from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod +from sglang.srt.layers.quantization.modelopt_quant import ( + ModelOptMxfp8MoEMethod, + ModelOptNvFp4FusedMoEMethod, +) from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.server_args import get_global_server_args diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 30b9f78de305..f87d3a134a38 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -203,27 +203,29 @@ def align_mxfp8_moe_weights_for_flashinfer_trtllm( w2_weight = cast(torch.Tensor, layer.w2_weight) num_experts = w13_weight.shape[0] - if quantize: - # Quantize BF16 weights to MXFP8 using flashinfer (swizzled scales for weights) - w13_q_list, w13_s_list = [], [] - w2_q_list, w2_s_list = [], [] - for i in range(num_experts): - w13_q, w13_s = mxfp8_quantize(w13_weight[i], is_sf_swizzled_layout=True) - w13_q_list.append(w13_q) - w13_s_list.append(w13_s.view(torch.uint8)) - - w2_q, w2_s = mxfp8_quantize(w2_weight[i], is_sf_swizzled_layout=True) - w2_q_list.append(w2_q) - w2_s_list.append(w2_s.view(torch.uint8)) - - w13_weight = torch.stack(w13_q_list) - w13_scales = torch.stack(w13_s_list) - w2_weight = torch.stack(w2_q_list) - w2_scales = torch.stack(w2_s_list) - else: - # Already quantized checkpoint — scales are already in layer - w13_scales = cast(torch.Tensor, layer.w13_weight_scale_inv) - w2_scales = cast(torch.Tensor, layer.w2_weight_scale_inv) + if not quantize: + # Pre-quantized FP8 checkpoint: upcast to BF16 so mxfp8_quantize can + # re-derive properly swizzled 1D scale factors (the checkpoint stores + # raw 2D UE8M0 scales which the trtllm kernel cannot consume directly). + w13_weight = w13_weight.to(torch.bfloat16) + w2_weight = w2_weight.to(torch.bfloat16) + + # Quantize (or re-quantize) to MXFP8 with swizzled scales for weights + w13_q_list, w13_s_list = [], [] + w2_q_list, w2_s_list = [], [] + for i in range(num_experts): + w13_q, w13_s = mxfp8_quantize(w13_weight[i], is_sf_swizzled_layout=True) + w13_q_list.append(w13_q) + w13_s_list.append(w13_s.view(torch.uint8)) + + w2_q, w2_s = mxfp8_quantize(w2_weight[i], is_sf_swizzled_layout=True) + w2_q_list.append(w2_q) + w2_s_list.append(w2_s.view(torch.uint8)) + + w13_weight = torch.stack(w13_q_list) + w13_scales = torch.stack(w13_s_list) + w2_weight = torch.stack(w2_q_list) + w2_scales = torch.stack(w2_s_list) # Reorder rows for gated activation (w13 only — interleaves gate/up halves) w13_interleaved = torch.stack( diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 48511d09f0bf..e3a55ae61f38 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -171,16 +171,22 @@ def _modelopt_override_quantization_method( # If user specified generic "modelopt", auto-detect the specific method if user_quant == "modelopt": - if "FP8" in quant_algo: + # Check MXFP8 before FP8 since "MXFP8" contains "FP8" + if "MXFP8" in quant_algo: + return "modelopt_mxfp8" + elif "FP8" in quant_algo: return "modelopt_fp8" elif "NVFP4" in quant_algo or "FP4" in quant_algo: return "modelopt_fp4" # The hf_quant_config may be a parsed quant config, so we need to check the # quant_method. - if hf_quant_config.get("quant_method", "") == "modelopt_fp8": + quant_method = hf_quant_config.get("quant_method", "") + if quant_method == "modelopt_mxfp8": + return "modelopt_mxfp8" + elif quant_method == "modelopt_fp8": return "modelopt_fp8" - elif hf_quant_config.get("quant_method", "") == "modelopt_fp4": + elif quant_method == "modelopt_fp4": return "modelopt_fp4" return None diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 8b05b990ed0f..831a824ace1a 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1951,7 +1951,11 @@ class ModelOptMxfp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt MXFP8 quantization. Loads FP8 weights with uint8 UE8M0 block scales (group_size=32). - Uses the existing MXFP8 blockscaled linear infrastructure. + Uses flashinfer's bmm_mxfp8 (cuDNN) for inference. + + During process_weights_after_loading, the weight [N, K] is transposed to + [K, N] and re-quantized with mxfp8_quantize so that the MXFP8 block scales + align with the K-contiguous layout expected by bmm_mxfp8. """ BLOCK_K = 32 # MXFP8 group size @@ -2007,9 +2011,19 @@ def create_weights( layer.register_parameter("weight_scale", scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Rename weight_scale → weight_scale_inv for downstream MXFP8 code + from flashinfer.fp8_quantization import mxfp8_quantize + + # Checkpoint weight is [N, K] fp8 with scales for [N, K] layout. + # bmm_mxfp8 needs B as [K, N] (column-major) with scales for that layout. + # Re-quantize the transposed weight to get correct MXFP8 block scales. + weight_bf16 = layer.weight.data.to(torch.bfloat16) # [N, K] + weight_t = weight_bf16.T.contiguous() # [K, N] row-major + weight_t_q, weight_t_scale = mxfp8_quantize( + weight_t, is_sf_swizzled_layout=False + ) + layer.weight = Parameter(weight_t_q, requires_grad=False) # [K, N] fp8 layer.weight_scale_inv = Parameter( - layer.weight_scale.data, requires_grad=False + weight_t_scale.view(torch.uint8), requires_grad=False ) def apply( @@ -2018,24 +2032,34 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from sglang.srt.layers.quantization.fp8_utils import ( - triton_mxfp8_blockscaled_linear, - ) + from flashinfer import bmm_mxfp8 + from flashinfer.fp8_quantization import mxfp8_quantize + input_2d = x.view(-1, x.shape[-1]) # [M, K] + output_shape = [*x.shape[:-1], layer.weight.shape[1]] # [..., N] + + # Quantize activations to MXFP8 if isinstance(x, tuple): - return triton_mxfp8_blockscaled_linear( - input=x[0], - weight=layer.weight, - weight_scale=layer.weight_scale_inv, - input_scale=x[1], - bias=bias, + input_q = input_2d + input_scale = x[1] + else: + input_q, input_scale = mxfp8_quantize( + input_2d, is_sf_swizzled_layout=False ) - return triton_mxfp8_blockscaled_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale_inv, - bias=bias, + + # bmm_mxfp8: A [1, M, K] @ B [1, K, N] → [1, M, N] + out = bmm_mxfp8( + input_q.unsqueeze(0), + layer.weight.unsqueeze(0), + input_scale, + layer.weight_scale_inv, + dtype=torch.bfloat16, ) + out = out.squeeze(0) # [M, N] + + if bias is not None: + out = out + bias + return out.view(*output_shape) class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): @@ -2089,9 +2113,7 @@ def create_weights( weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) # Block scales: uint8 UE8M0, named "weight_scale" to match ModelOpt checkpoint # Shape: [num_experts, N, K // 32] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index caca351c1206..3fb280ec7138 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -102,6 +102,7 @@ "bitsandbytes", "gguf", "modelopt", + "modelopt_mxfp8", "modelopt_fp8", "modelopt_fp4", "petit_nvfp4", @@ -1297,7 +1298,7 @@ def _handle_model_specific_adjustments(self): if ( self.moe_a2a_backend == "none" and self.moe_runner_backend == "auto" - and self.quantization in ["fp8", "modelopt_fp8", "modelopt_fp4"] + and self.quantization in ["fp8", "modelopt_fp8", "modelopt_mxfp8", "modelopt_fp4"] ): self.moe_runner_backend = "flashinfer_trtllm" logger.info( @@ -1456,7 +1457,7 @@ def _handle_model_specific_adjustments(self): "intel_xpu", }, f"fa3, aiter, triton, trtllm_mha or intel_xpu is required for Llama4 model but got {self.attention_backend}" if is_sm100_supported() and self.moe_runner_backend == "auto": - if self.quantization in {"fp8", "modelopt_fp8"}: + if self.quantization in {"fp8", "modelopt_fp8", "modelopt_mxfp8"}: self.moe_runner_backend = "flashinfer_trtllm" logger.info( "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4" @@ -2078,6 +2079,7 @@ def _handle_moe_kernel_config(self): "modelopt_fp4", "fp8", "modelopt_fp8", + "modelopt_mxfp8", "compressed-tensors", None, ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'compressed-tensors', or bfloat16 (None)." From 92ea9238fdd5cccc17e507d1449ba480429a5a92 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 04:39:18 +0000 Subject: [PATCH 04/14] WIPfix Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../moe/moe_runner/flashinfer_trtllm.py | 42 ++++++++++++++++--- .../srt/layers/quantization/modelopt_quant.py | 8 +++- python/sglang/srt/server_args.py | 7 ++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index f87d3a134a38..17501c5bc09b 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, cast @@ -7,6 +8,8 @@ from torch.nn import Module from torch.nn.parameter import Parameter +logger = logging.getLogger(__name__) + from sglang.srt.distributed import get_tp_group from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, @@ -349,12 +352,41 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. # If the bug is fixed, we can only put the output tensor allocation under the context manager. + _rl = router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits + logger.info( + "trtllm_fp8_block_scale_moe call:\n" + " fp8_quantization_type=%s\n" + " routing_logits: shape=%s dtype=%s\n" + " routing_bias: %s\n" + " hidden_states (a_q): shape=%s dtype=%s\n" + " hidden_states_scale (a_sf): shape=%s dtype=%s\n" + " gemm1_weights (w13): shape=%s dtype=%s\n" + " gemm1_weights_scale (w13_scale_inv): shape=%s dtype=%s\n" + " gemm2_weights (w2): shape=%s dtype=%s\n" + " gemm2_weights_scale (w2_scale_inv): shape=%s dtype=%s\n" + " num_experts=%s top_k=%s n_group=%s topk_group=%s\n" + " intermediate_size=%s local_expert_offset=%s local_num_experts=%s\n" + " routed_scaling_factor=%s routing_method_type=%s\n" + " use_shuffled_weight=%s tune_max_num_tokens=%s", + fp8_quant_type, + _rl.shape, _rl.dtype, + f"shape={correction_bias.shape} dtype={correction_bias.dtype}" if correction_bias is not None else "None", + a_q.shape, a_q.dtype, + a_sf.shape, a_sf.dtype, + quant_info.w13_weight.shape, quant_info.w13_weight.dtype, + quant_info.w13_weight_scale_inv.shape, quant_info.w13_weight_scale_inv.dtype, + quant_info.w2_weight.shape, quant_info.w2_weight.dtype, + quant_info.w2_weight_scale_inv.shape, quant_info.w2_weight_scale_inv.dtype, + quant_info.global_num_experts, topk_config.top_k, + topk_config.num_expert_group if topk_config.num_expert_group else 0, + topk_config.topk_group if topk_config.topk_group else 0, + quant_info.intermediate_size, quant_info.local_expert_offset, quant_info.local_num_experts, + runner_config.routed_scaling_factor if runner_config.routed_scaling_factor is not None else 1.0, + routing_method_type, + use_shuffled_weight, next_power_of_2(a_q.shape[0]), + ) output = trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), + routing_logits=_rl, routing_bias=correction_bias, hidden_states=a_q, hidden_states_scale=a_sf, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 831a824ace1a..e5cb194dd3a5 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2208,7 +2208,13 @@ def apply( local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, local_num_experts=layer.num_local_experts, intermediate_size=layer.w2_weight.shape[2], - routing_method_type=RoutingMethodType.Default, + routing_method_type=int( + getattr( + layer, + "routing_method_type", + RoutingMethodType.Renormalize, + ) + ), block_quant=True, w13_weight_scale_inv=layer.w13_weight_scale_inv, w2_weight_scale_inv=layer.w2_weight_scale_inv, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3fb280ec7138..68455dd42341 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2055,6 +2055,13 @@ def _handle_data_parallelism(self): ), "Please enable dp attention when setting enable_dp_lm_head. " def _handle_moe_kernel_config(self): + if self.quantization == "modelopt_mxfp8" and not self.disable_cuda_graph: + logger.warning( + "Cuda graph is disabled for modelopt_mxfp8 because bmm_mxfp8 currently uses " + "a cuDNN path that is not graph-capture safe." + ) + self.disable_cuda_graph = True + if self.quantization == "mxfp8": if self.moe_runner_backend not in ["auto", "cutlass"]: logger.warning( From 10eb3a5cd211abb93c5f2ebf902c7c99941db262 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 05:09:29 +0000 Subject: [PATCH 05/14] modelopt_mxfp8: switch linear path to mm_mxfp8 and fix cuda-graph capture --- .../srt/layers/quantization/modelopt_quant.py | 87 ++++++++++++------- python/sglang/srt/server_args.py | 10 +-- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index e5cb194dd3a5..c199cbb53abe 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1951,11 +1951,11 @@ class ModelOptMxfp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt MXFP8 quantization. Loads FP8 weights with uint8 UE8M0 block scales (group_size=32). - Uses flashinfer's bmm_mxfp8 (cuDNN) for inference. + Uses flashinfer's mm_mxfp8 for inference. - During process_weights_after_loading, the weight [N, K] is transposed to - [K, N] and re-quantized with mxfp8_quantize so that the MXFP8 block scales - align with the K-contiguous layout expected by bmm_mxfp8. + During process_weights_after_loading, checkpoint MXFP8 weights are + dequantized to bf16 and re-quantized with flashinfer's swizzled scale + layout required by mm_mxfp8. """ BLOCK_K = 32 # MXFP8 group size @@ -1964,6 +1964,19 @@ def __init__(self, quant_config: ModelOptMxfp8Config): super().__init__() self.quant_config = quant_config + @staticmethod + def _dequantize_mxfp8(data: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP8 tensor with UE8M0 scales back to bf16.""" + group_size = 32 + m, k = data.shape + n_groups = k // group_size + scales_f32 = torch.pow( + 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 + ) + data_f32 = data.to(torch.float32).view(m, n_groups, group_size) + scales_f32 = scales_f32.view(m, n_groups, 1) + return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) + def create_weights( self, layer: torch.nn.Module, @@ -2013,17 +2026,16 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from flashinfer.fp8_quantization import mxfp8_quantize - # Checkpoint weight is [N, K] fp8 with scales for [N, K] layout. - # bmm_mxfp8 needs B as [K, N] (column-major) with scales for that layout. - # Re-quantize the transposed weight to get correct MXFP8 block scales. - weight_bf16 = layer.weight.data.to(torch.bfloat16) # [N, K] - weight_t = weight_bf16.T.contiguous() # [K, N] row-major - weight_t_q, weight_t_scale = mxfp8_quantize( - weight_t, is_sf_swizzled_layout=False - ) - layer.weight = Parameter(weight_t_q, requires_grad=False) # [K, N] fp8 + # Checkpoint MXFP8 stores raw FP8 data and per-block UE8M0 scales. + # Reconstruct bf16 values first, then requantize to flashinfer swizzled + # scales so mm_mxfp8 can consume them directly. + weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) + weight_q, weight_scale = mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=True) + + # mm_mxfp8 expects B as [K, N] column-major. + layer.weight = Parameter(weight_q.t(), requires_grad=False) layer.weight_scale_inv = Parameter( - weight_t_scale.view(torch.uint8), requires_grad=False + weight_scale.view(torch.uint8), requires_grad=False ) def apply( @@ -2032,30 +2044,47 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from flashinfer import bmm_mxfp8 + from flashinfer import mm_mxfp8 from flashinfer.fp8_quantization import mxfp8_quantize - input_2d = x.view(-1, x.shape[-1]) # [M, K] - output_shape = [*x.shape[:-1], layer.weight.shape[1]] # [..., N] + input_tensor = x[0] if isinstance(x, tuple) else x + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) # [M, K] + output_shape = [*input_tensor.shape[:-1], layer.weight.shape[1]] # [..., N] + m_actual = input_2d.shape[0] + k_dim = input_2d.shape[1] - # Quantize activations to MXFP8 + # Quantize activations to MXFP8 using swizzled scales for mm_mxfp8. + # If input is already pre-quantized, dequantize first when a 2D UE8M0 + # scale tensor is provided. if isinstance(x, tuple): - input_q = input_2d input_scale = x[1] - else: - input_q, input_scale = mxfp8_quantize( - input_2d, is_sf_swizzled_layout=False + if input_scale.dim() == 2: + input_2d = self._dequantize_mxfp8( + input_2d, + input_scale.view(-1, input_scale.shape[-1]), + ) + + # CUTLASS MXFP8 requires M >= 32. Pad decode/capture micro-batches + # and slice the output back to original M. + if m_actual < 32: + input_padded = torch.zeros( + (32, k_dim), + dtype=input_2d.dtype, + device=input_2d.device, ) + input_padded[:m_actual, :] = input_2d + input_2d = input_padded + + input_q, input_scale = mxfp8_quantize(input_2d, is_sf_swizzled_layout=True) - # bmm_mxfp8: A [1, M, K] @ B [1, K, N] → [1, M, N] - out = bmm_mxfp8( - input_q.unsqueeze(0), - layer.weight.unsqueeze(0), - input_scale, + out = mm_mxfp8( + input_q, + layer.weight, + input_scale.view(torch.uint8), layer.weight_scale_inv, - dtype=torch.bfloat16, + out_dtype=torch.bfloat16, ) - out = out.squeeze(0) # [M, N] + out = out[:m_actual, :] if bias is not None: out = out + bias diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 68455dd42341..b851371f9b02 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1298,7 +1298,8 @@ def _handle_model_specific_adjustments(self): if ( self.moe_a2a_backend == "none" and self.moe_runner_backend == "auto" - and self.quantization in ["fp8", "modelopt_fp8", "modelopt_mxfp8", "modelopt_fp4"] + and self.quantization + in ["fp8", "modelopt_fp8", "modelopt_mxfp8", "modelopt_fp4"] ): self.moe_runner_backend = "flashinfer_trtllm" logger.info( @@ -2055,13 +2056,6 @@ def _handle_data_parallelism(self): ), "Please enable dp attention when setting enable_dp_lm_head. " def _handle_moe_kernel_config(self): - if self.quantization == "modelopt_mxfp8" and not self.disable_cuda_graph: - logger.warning( - "Cuda graph is disabled for modelopt_mxfp8 because bmm_mxfp8 currently uses " - "a cuDNN path that is not graph-capture safe." - ) - self.disable_cuda_graph = True - if self.quantization == "mxfp8": if self.moe_runner_backend not in ["auto", "cutlass"]: logger.warning( From 48f64f7bd0498bede6176cc5ad335ab2581e0871 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 05:11:30 +0000 Subject: [PATCH 06/14] mxfp8: pin quantize backend to cuda for modelopt and trtllm moe --- .../moe/moe_runner/flashinfer_trtllm.py | 59 ++++++++++++++----- .../srt/layers/quantization/modelopt_quant.py | 8 ++- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 17501c5bc09b..f512b3f6742d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -217,11 +217,15 @@ def align_mxfp8_moe_weights_for_flashinfer_trtllm( w13_q_list, w13_s_list = [], [] w2_q_list, w2_s_list = [], [] for i in range(num_experts): - w13_q, w13_s = mxfp8_quantize(w13_weight[i], is_sf_swizzled_layout=True) + w13_q, w13_s = mxfp8_quantize( + w13_weight[i], is_sf_swizzled_layout=True, backend="cuda" + ) w13_q_list.append(w13_q) w13_s_list.append(w13_s.view(torch.uint8)) - w2_q, w2_s = mxfp8_quantize(w2_weight[i], is_sf_swizzled_layout=True) + w2_q, w2_s = mxfp8_quantize( + w2_weight[i], is_sf_swizzled_layout=True, backend="cuda" + ) w2_q_list.append(w2_q) w2_s_list.append(w2_s.view(torch.uint8)) @@ -330,7 +334,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( from flashinfer.fp8_quantization import mxfp8_quantize a_q, a_sf = mxfp8_quantize( - hidden_states, is_sf_swizzled_layout=False + hidden_states, is_sf_swizzled_layout=False, backend="cuda" ) a_sf = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1) fp8_quant_type = Fp8QuantizationType.MxFp8 @@ -352,7 +356,11 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. # If the bug is fixed, we can only put the output tensor allocation under the context manager. - _rl = router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits + _rl = ( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) logger.info( "trtllm_fp8_block_scale_moe call:\n" " fp8_quantization_type=%s\n" @@ -369,21 +377,40 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( " routed_scaling_factor=%s routing_method_type=%s\n" " use_shuffled_weight=%s tune_max_num_tokens=%s", fp8_quant_type, - _rl.shape, _rl.dtype, - f"shape={correction_bias.shape} dtype={correction_bias.dtype}" if correction_bias is not None else "None", - a_q.shape, a_q.dtype, - a_sf.shape, a_sf.dtype, - quant_info.w13_weight.shape, quant_info.w13_weight.dtype, - quant_info.w13_weight_scale_inv.shape, quant_info.w13_weight_scale_inv.dtype, - quant_info.w2_weight.shape, quant_info.w2_weight.dtype, - quant_info.w2_weight_scale_inv.shape, quant_info.w2_weight_scale_inv.dtype, - quant_info.global_num_experts, topk_config.top_k, + _rl.shape, + _rl.dtype, + ( + f"shape={correction_bias.shape} dtype={correction_bias.dtype}" + if correction_bias is not None + else "None" + ), + a_q.shape, + a_q.dtype, + a_sf.shape, + a_sf.dtype, + quant_info.w13_weight.shape, + quant_info.w13_weight.dtype, + quant_info.w13_weight_scale_inv.shape, + quant_info.w13_weight_scale_inv.dtype, + quant_info.w2_weight.shape, + quant_info.w2_weight.dtype, + quant_info.w2_weight_scale_inv.shape, + quant_info.w2_weight_scale_inv.dtype, + quant_info.global_num_experts, + topk_config.top_k, topk_config.num_expert_group if topk_config.num_expert_group else 0, topk_config.topk_group if topk_config.topk_group else 0, - quant_info.intermediate_size, quant_info.local_expert_offset, quant_info.local_num_experts, - runner_config.routed_scaling_factor if runner_config.routed_scaling_factor is not None else 1.0, + quant_info.intermediate_size, + quant_info.local_expert_offset, + quant_info.local_num_experts, + ( + runner_config.routed_scaling_factor + if runner_config.routed_scaling_factor is not None + else 1.0 + ), routing_method_type, - use_shuffled_weight, next_power_of_2(a_q.shape[0]), + use_shuffled_weight, + next_power_of_2(a_q.shape[0]), ) output = trtllm_fp8_block_scale_moe( routing_logits=_rl, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c199cbb53abe..4100984836cc 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2030,7 +2030,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Reconstruct bf16 values first, then requantize to flashinfer swizzled # scales so mm_mxfp8 can consume them directly. weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) - weight_q, weight_scale = mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=True) + weight_q, weight_scale = mxfp8_quantize( + weight_bf16, is_sf_swizzled_layout=True, backend="cuda" + ) # mm_mxfp8 expects B as [K, N] column-major. layer.weight = Parameter(weight_q.t(), requires_grad=False) @@ -2075,7 +2077,9 @@ def apply( input_padded[:m_actual, :] = input_2d input_2d = input_padded - input_q, input_scale = mxfp8_quantize(input_2d, is_sf_swizzled_layout=True) + input_q, input_scale = mxfp8_quantize( + input_2d, is_sf_swizzled_layout=True, backend="cuda" + ) out = mm_mxfp8( input_q, From 39b6b1d50d2bf3918b938b60062bbcce1193b5d7 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 15:30:59 +0000 Subject: [PATCH 07/14] modelopt_mxfp8: run MoE as BF16 Triton fallback for correctness --- .../srt/layers/quantization/modelopt_quant.py | 218 +++++++----------- 1 file changed, 83 insertions(+), 135 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 4100984836cc..2e832399ec1d 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1951,11 +1951,7 @@ class ModelOptMxfp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt MXFP8 quantization. Loads FP8 weights with uint8 UE8M0 block scales (group_size=32). - Uses flashinfer's mm_mxfp8 for inference. - - During process_weights_after_loading, checkpoint MXFP8 weights are - dequantized to bf16 and re-quantized with flashinfer's swizzled scale - layout required by mm_mxfp8. + Uses cuBLAS torch._scaled_mm for inference. """ BLOCK_K = 32 # MXFP8 group size @@ -1964,19 +1960,6 @@ def __init__(self, quant_config: ModelOptMxfp8Config): super().__init__() self.quant_config = quant_config - @staticmethod - def _dequantize_mxfp8(data: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor: - """Dequantize MXFP8 tensor with UE8M0 scales back to bf16.""" - group_size = 32 - m, k = data.shape - n_groups = k // group_size - scales_f32 = torch.pow( - 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 - ) - data_f32 = data.to(torch.float32).view(m, n_groups, group_size) - scales_f32 = scales_f32.view(m, n_groups, 1) - return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) - def create_weights( self, layer: torch.nn.Module, @@ -2024,20 +2007,30 @@ def create_weights( layer.register_parameter("weight_scale", scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - from flashinfer.fp8_quantization import mxfp8_quantize - - # Checkpoint MXFP8 stores raw FP8 data and per-block UE8M0 scales. - # Reconstruct bf16 values first, then requantize to flashinfer swizzled - # scales so mm_mxfp8 can consume them directly. - weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) - weight_q, weight_scale = mxfp8_quantize( - weight_bf16, is_sf_swizzled_layout=True, backend="cuda" - ) - - # mm_mxfp8 expects B as [K, N] column-major. - layer.weight = Parameter(weight_q.t(), requires_grad=False) - layer.weight_scale_inv = Parameter( - weight_scale.view(torch.uint8), requires_grad=False + from sglang.srt.layers.quantization.fp8_utils import ( + prepare_mxfp8_weight_for_cublas, + ) + + # Keep checkpoint-native FP8 + UE8M0 tensors and precompute cuBLAS + # views/layouts once at load time. + layer.weight = Parameter(layer.weight.data, requires_grad=False) + layer.weight_scale_inv = Parameter(layer.weight_scale.data, requires_grad=False) + layer.weight_scale_inv.format_ue8m0 = True + if not hasattr(ModelOptMxfp8LinearMethod, "_debug_scale_log_count"): + ModelOptMxfp8LinearMethod._debug_scale_log_count = 0 + if ModelOptMxfp8LinearMethod._debug_scale_log_count < 8: + s = layer.weight_scale_inv.data + logger.info( + "ModelOpt MXFP8 linear load stats: shape=%s dtype=%s min=%s max=%s mean=%.2f", + tuple(s.shape), + s.dtype, + int(s.min().item()), + int(s.max().item()), + float(s.float().mean().item()), + ) + ModelOptMxfp8LinearMethod._debug_scale_log_count += 1 + layer.weight_t, layer.weight_scale_cublas = prepare_mxfp8_weight_for_cublas( + layer.weight.data, layer.weight_scale_inv.data ) def apply( @@ -2046,60 +2039,30 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from flashinfer import mm_mxfp8 - from flashinfer.fp8_quantization import mxfp8_quantize - - input_tensor = x[0] if isinstance(x, tuple) else x - input_2d = input_tensor.view(-1, input_tensor.shape[-1]) # [M, K] - output_shape = [*input_tensor.shape[:-1], layer.weight.shape[1]] # [..., N] - m_actual = input_2d.shape[0] - k_dim = input_2d.shape[1] - - # Quantize activations to MXFP8 using swizzled scales for mm_mxfp8. - # If input is already pre-quantized, dequantize first when a 2D UE8M0 - # scale tensor is provided. - if isinstance(x, tuple): - input_scale = x[1] - if input_scale.dim() == 2: - input_2d = self._dequantize_mxfp8( - input_2d, - input_scale.view(-1, input_scale.shape[-1]), - ) - - # CUTLASS MXFP8 requires M >= 32. Pad decode/capture micro-batches - # and slice the output back to original M. - if m_actual < 32: - input_padded = torch.zeros( - (32, k_dim), - dtype=input_2d.dtype, - device=input_2d.device, - ) - input_padded[:m_actual, :] = input_2d - input_2d = input_padded - - input_q, input_scale = mxfp8_quantize( - input_2d, is_sf_swizzled_layout=True, backend="cuda" + from sglang.srt.layers.quantization.fp8_utils import ( + cublas_mxfp8_blockscaled_linear, ) - out = mm_mxfp8( - input_q, - layer.weight, - input_scale.view(torch.uint8), - layer.weight_scale_inv, - out_dtype=torch.bfloat16, - ) - out = out[:m_actual, :] + if isinstance(x, tuple): + input_data, input_scale = x + else: + input_data, input_scale = x, None - if bias is not None: - out = out + bias - return out.view(*output_shape) + return cublas_mxfp8_blockscaled_linear( + input=input_data, + weight_t=layer.weight_t, + weight_scale_cublas=layer.weight_scale_cublas, + input_scale=input_scale, + bias=bias, + ) class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt MXFP8 quantization. Loads FP8 expert weights with uint8 UE8M0 block scales. - Uses the existing MXFP8 MoE infrastructure for inference. + For correctness, this path dequantizes expert weights to BF16 at load time + and runs MoE with the Triton unquantized kernel path. """ BLOCK_K = 32 # MXFP8 group size @@ -2183,7 +2146,8 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs_block) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Rename weight_scale → weight_scale_inv for downstream MXFP8 code + # Rename weight_scale → weight_scale_inv to keep parameter naming + # consistent with existing MoE code paths. layer.w13_weight_scale_inv = Parameter( layer.w13_weight_scale.data, requires_grad=False ) @@ -2192,81 +2156,65 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) - - # Use existing MXFP8 MoE weight preparation infrastructure - if get_moe_runner_backend().is_flashinfer_trtllm(): - from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( - align_mxfp8_moe_weights_for_flashinfer_trtllm, + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + + # Correctness fallback: convert checkpoint MXFP8 MoE weights to BF16. + # This avoids backend-specific MXFP8 MoE scale-layout mismatches. + def _dequantize_mxfp8_3d( + fp8_weight: torch.Tensor, scale_u8: torch.Tensor + ) -> torch.Tensor: + # fp8_weight: [E, N, K], scale_u8: [E, N, K//32] + group_size = 32 + num_experts, n_dim, k_dim = fp8_weight.shape + assert ( + k_dim % group_size == 0 + ), f"K={k_dim} must be divisible by {group_size} for MXFP8." + scales_f32 = torch.pow( + 2.0, scale_u8.to(dtype=torch.float32, device=fp8_weight.device) - 127.0 + ) + weight_f32 = fp8_weight.to(torch.float32).view( + num_experts, n_dim, k_dim // group_size, group_size + ) + return ( + (weight_f32 * scales_f32.unsqueeze(-1)) + .view(num_experts, n_dim, k_dim) + .to(torch.bfloat16) ) - # Weights are already quantized FP8 from checkpoint — no re-quantization - align_mxfp8_moe_weights_for_flashinfer_trtllm(layer, quantize=False) + layer.w13_weight = Parameter( + _dequantize_mxfp8_3d( + layer.w13_weight.data, layer.w13_weight_scale_inv.data + ), + requires_grad=False, + ) + layer.w2_weight = Parameter( + _dequantize_mxfp8_3d(layer.w2_weight.data, layer.w2_weight_scale_inv.data), + requires_grad=False, + ) def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - if get_moe_runner_backend().is_flashinfer_trtllm(): - self.runner = MoeRunner( - MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config + requested_backend = get_moe_runner_backend() + if requested_backend.is_flashinfer_trtllm(): + logger.warning( + "modelopt_mxfp8 MoE falls back to Triton BF16 execution for " + "correctness; ignoring flashinfer_trtllm runner for this layer." ) - else: - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply( self, layer: torch.nn.Module, dispatch_output: "StandardDispatchOutput", ) -> "CombineInput": - from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( - FlashInferTrtllmFp8MoeQuantInfo, - fused_experts_none_to_flashinfer_trtllm_fp8, - ) - from sglang.srt.layers.moe.topk import TopKOutputChecker - from sglang.srt.layers.moe.utils import RoutingMethodType - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - if ( - get_moe_runner_backend().is_flashinfer_trtllm() - and TopKOutputChecker.format_is_bypassed(topk_output) - ): - topk_config = topk_output.topk_config - - quant_info = FlashInferTrtllmFp8MoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - global_num_experts=layer.num_experts, - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - intermediate_size=layer.w2_weight.shape[2], - routing_method_type=int( - getattr( - layer, - "routing_method_type", - RoutingMethodType.Renormalize, - ) - ), - block_quant=True, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight_scale_inv=layer.w2_weight_scale_inv, - use_mxfp8=True, - ) - - return fused_experts_none_to_flashinfer_trtllm_fp8( - dispatch_output, quant_info, self.moe_runner_config - ) - - # Fallback: triton path + # Run as unquantized Triton MoE with dequantized BF16 weights. from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo quant_info = TritonMoeQuantInfo( w13_weight=layer.w13_weight, w2_weight=layer.w2_weight, - use_fp8_w8a8=True, - per_channel_quant=False, - w13_scale=layer.w13_weight_scale_inv, - w2_scale=layer.w2_weight_scale_inv, ) return self.runner.run(dispatch_output, quant_info) From 608da34024a642d810aa8f01ca47d8dd0acc0c64 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 15:35:04 +0000 Subject: [PATCH 08/14] mxfp8: add cuBLAS helpers and stabilize flashinfer mx quantize path --- .../moe/moe_runner/flashinfer_trtllm.py | 86 +++++--------- .../srt/layers/quantization/fp8_utils.py | 105 ++++++++++++++++++ 2 files changed, 132 insertions(+), 59 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index f512b3f6742d..875a9dfabf63 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging from dataclasses import dataclass from typing import TYPE_CHECKING, cast @@ -8,8 +7,6 @@ from torch.nn import Module from torch.nn.parameter import Parameter -logger = logging.getLogger(__name__) - from sglang.srt.distributed import get_tp_group from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, @@ -206,12 +203,34 @@ def align_mxfp8_moe_weights_for_flashinfer_trtllm( w2_weight = cast(torch.Tensor, layer.w2_weight) num_experts = w13_weight.shape[0] + def _dequantize_mxfp8_2d(fp8_weight: torch.Tensor, scale_u8: torch.Tensor): + group_size = 32 + n, k = fp8_weight.shape + n_groups = k // group_size + scales_f32 = torch.pow( + 2.0, scale_u8.to(dtype=torch.float32, device=fp8_weight.device) - 127.0 + ) + weight_f32 = fp8_weight.to(torch.float32).view(n, n_groups, group_size) + return (weight_f32 * scales_f32.unsqueeze(-1)).view(n, k).to(torch.bfloat16) + if not quantize: - # Pre-quantized FP8 checkpoint: upcast to BF16 so mxfp8_quantize can - # re-derive properly swizzled 1D scale factors (the checkpoint stores - # raw 2D UE8M0 scales which the trtllm kernel cannot consume directly). - w13_weight = w13_weight.to(torch.bfloat16) - w2_weight = w2_weight.to(torch.bfloat16) + # Pre-quantized FP8 checkpoint stores raw FP8 mantissas plus 2D UE8M0 + # block scales. Reconstruct bf16 first, then requantize to flashinfer's + # swizzled scale layout expected by the TRT-LLM kernel. + w13_scale = cast(torch.Tensor, layer.w13_weight_scale_inv) + w2_scale = cast(torch.Tensor, layer.w2_weight_scale_inv) + w13_weight = torch.stack( + [ + _dequantize_mxfp8_2d(w13_weight[i], w13_scale[i]) + for i in range(num_experts) + ] + ) + w2_weight = torch.stack( + [ + _dequantize_mxfp8_2d(w2_weight[i], w2_scale[i]) + for i in range(num_experts) + ] + ) # Quantize (or re-quantize) to MXFP8 with swizzled scales for weights w13_q_list, w13_s_list = [], [] @@ -361,57 +380,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits ) - logger.info( - "trtllm_fp8_block_scale_moe call:\n" - " fp8_quantization_type=%s\n" - " routing_logits: shape=%s dtype=%s\n" - " routing_bias: %s\n" - " hidden_states (a_q): shape=%s dtype=%s\n" - " hidden_states_scale (a_sf): shape=%s dtype=%s\n" - " gemm1_weights (w13): shape=%s dtype=%s\n" - " gemm1_weights_scale (w13_scale_inv): shape=%s dtype=%s\n" - " gemm2_weights (w2): shape=%s dtype=%s\n" - " gemm2_weights_scale (w2_scale_inv): shape=%s dtype=%s\n" - " num_experts=%s top_k=%s n_group=%s topk_group=%s\n" - " intermediate_size=%s local_expert_offset=%s local_num_experts=%s\n" - " routed_scaling_factor=%s routing_method_type=%s\n" - " use_shuffled_weight=%s tune_max_num_tokens=%s", - fp8_quant_type, - _rl.shape, - _rl.dtype, - ( - f"shape={correction_bias.shape} dtype={correction_bias.dtype}" - if correction_bias is not None - else "None" - ), - a_q.shape, - a_q.dtype, - a_sf.shape, - a_sf.dtype, - quant_info.w13_weight.shape, - quant_info.w13_weight.dtype, - quant_info.w13_weight_scale_inv.shape, - quant_info.w13_weight_scale_inv.dtype, - quant_info.w2_weight.shape, - quant_info.w2_weight.dtype, - quant_info.w2_weight_scale_inv.shape, - quant_info.w2_weight_scale_inv.dtype, - quant_info.global_num_experts, - topk_config.top_k, - topk_config.num_expert_group if topk_config.num_expert_group else 0, - topk_config.topk_group if topk_config.topk_group else 0, - quant_info.intermediate_size, - quant_info.local_expert_offset, - quant_info.local_num_experts, - ( - runner_config.routed_scaling_factor - if runner_config.routed_scaling_factor is not None - else 1.0 - ), - routing_method_type, - use_shuffled_weight, - next_power_of_2(a_q.shape[0]), - ) output = trtllm_fp8_block_scale_moe( routing_logits=_rl, routing_bias=correction_bias, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3b9616e2798f..31b3db73b5ce 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -654,6 +654,111 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor: return packed.view(1, scale_m, scale_k, 2, 256) +def _interleave_mxfp8_scales_for_cublas( + scale_u8: torch.Tensor, m_padded: int, k: int +) -> torch.Tensor: + """Convert natural MXFP8 scales (M, K//32) to cuBLAS interleaved layout.""" + k_groups = k // 32 + return ( + scale_u8.view(m_padded // 128, 4, 32, k // 128, 4) + .permute(0, 3, 2, 1, 4) + .contiguous() + .view(m_padded, k_groups) + .view(torch.float8_e8m0fnu) + ) + + +def prepare_mxfp8_weight_for_cublas( + weight: torch.Tensor, weight_scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pre-compute (K, N) weight view and interleaved scales for torch._scaled_mm.""" + n, k = weight.shape + assert n % 128 == 0, f"N={n} must be divisible by 128 for MXFP8" + assert k % 128 == 0, f"K={k} must be divisible by 128 for MXFP8" + # Keep non-contiguous transpose view for column-major B expected by cuBLAS. + weight_t = weight.t() + weight_scale_cublas = _interleave_mxfp8_scales_for_cublas(weight_scale, n, k) + return weight_t, weight_scale_cublas + + +def cublas_mxfp8_blockscaled_linear( + input: torch.Tensor, + weight_t: torch.Tensor, + weight_scale_cublas: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """MXFP8 dense linear via cuBLAS torch._scaled_mm.""" + if not (_is_cuda and is_sm100_supported()): + raise RuntimeError("MXFP8 cuBLAS linear requires Blackwell GPUs (SM100+).") + + if not hasattr(cublas_mxfp8_blockscaled_linear, "_logged"): + logger.info("Using cuBLAS MXFP8 dense linear (torch._scaled_mm)") + cublas_mxfp8_blockscaled_linear._logged = True + + input_2d = input.view(-1, input.shape[-1]).contiguous() + k, n = weight_t.shape + output_shape = [*input.shape[:-1], n] + m = input_2d.shape[0] + assert ( + input_2d.shape[1] == k + ), f"K mismatch: input {input_2d.shape[1]} vs weight {k}" + + if input_scale is None: + q_input, x_scale_u8 = mxfp8_group_quantize(input_2d) + else: + q_input = input_2d + x_scale_u8 = input_scale + assert x_scale_u8.dtype == torch.uint8, "MXFP8 input_scale must be UE8M0 uint8." + assert x_scale_u8.shape == (m, k // 32) + + if output_dtype is None: + if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32): + output_dtype = input_2d.dtype + else: + output_dtype = torch.bfloat16 + + if m % 128 != 0: + m_padded = ceil_div(m, 128) * 128 + pad_rows = m_padded - m + q_input = torch.cat( + [ + q_input, + torch.zeros((pad_rows, k), device=q_input.device, dtype=q_input.dtype), + ], + dim=0, + ) + x_scale_u8 = torch.cat( + [ + x_scale_u8, + torch.full( + (pad_rows, k // 32), + 127, + device=x_scale_u8.device, + dtype=x_scale_u8.dtype, + ), + ], + dim=0, + ) + else: + m_padded = m + + sa = _interleave_mxfp8_scales_for_cublas(x_scale_u8, m_padded, k) + output = torch._scaled_mm( + q_input, + weight_t, + scale_a=sa, + scale_b=weight_scale_cublas, + out_dtype=output_dtype, + ) + + output = output[:m, :] + if bias is not None: + output += bias + return output.view(*output_shape) + + def triton_mxfp8_blockscaled_linear( input: torch.Tensor, weight: torch.Tensor, From 34f54cd6bcd7eb9d2de270c165cc751966f52c27 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 15:52:20 +0000 Subject: [PATCH 09/14] WIP use mm_mxfp8 Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../srt/layers/quantization/modelopt_quant.py | 108 ++++++++++++------ 1 file changed, 71 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 2e832399ec1d..e58c049c11f1 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1951,7 +1951,11 @@ class ModelOptMxfp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt MXFP8 quantization. Loads FP8 weights with uint8 UE8M0 block scales (group_size=32). - Uses cuBLAS torch._scaled_mm for inference. + Uses flashinfer's mm_mxfp8 for inference. + + During process_weights_after_loading, checkpoint MXFP8 weights are + dequantized to bf16 and re-quantized with flashinfer's swizzled scale + layout required by mm_mxfp8. """ BLOCK_K = 32 # MXFP8 group size @@ -1960,6 +1964,19 @@ def __init__(self, quant_config: ModelOptMxfp8Config): super().__init__() self.quant_config = quant_config + @staticmethod + def _dequantize_mxfp8(data: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP8 tensor with UE8M0 scales back to bf16.""" + group_size = 32 + m, k = data.shape + n_groups = k // group_size + scales_f32 = torch.pow( + 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 + ) + data_f32 = data.to(torch.float32).view(m, n_groups, group_size) + scales_f32 = scales_f32.view(m, n_groups, 1) + return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) + def create_weights( self, layer: torch.nn.Module, @@ -2007,30 +2024,18 @@ def create_weights( layer.register_parameter("weight_scale", scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - from sglang.srt.layers.quantization.fp8_utils import ( - prepare_mxfp8_weight_for_cublas, - ) + from flashinfer.fp8_quantization import mxfp8_quantize - # Keep checkpoint-native FP8 + UE8M0 tensors and precompute cuBLAS - # views/layouts once at load time. - layer.weight = Parameter(layer.weight.data, requires_grad=False) - layer.weight_scale_inv = Parameter(layer.weight_scale.data, requires_grad=False) - layer.weight_scale_inv.format_ue8m0 = True - if not hasattr(ModelOptMxfp8LinearMethod, "_debug_scale_log_count"): - ModelOptMxfp8LinearMethod._debug_scale_log_count = 0 - if ModelOptMxfp8LinearMethod._debug_scale_log_count < 8: - s = layer.weight_scale_inv.data - logger.info( - "ModelOpt MXFP8 linear load stats: shape=%s dtype=%s min=%s max=%s mean=%.2f", - tuple(s.shape), - s.dtype, - int(s.min().item()), - int(s.max().item()), - float(s.float().mean().item()), - ) - ModelOptMxfp8LinearMethod._debug_scale_log_count += 1 - layer.weight_t, layer.weight_scale_cublas = prepare_mxfp8_weight_for_cublas( - layer.weight.data, layer.weight_scale_inv.data + # Checkpoint MXFP8 stores raw FP8 data and per-block UE8M0 scales. + # Reconstruct bf16 values first, then requantize to flashinfer swizzled + # scales so mm_mxfp8 can consume them directly. + weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) + weight_q, weight_scale = mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=True) + + # mm_mxfp8 expects B as [K, N] column-major. + layer.weight = Parameter(weight_q.t(), requires_grad=False) + layer.weight_scale_inv = Parameter( + weight_scale.view(torch.uint8), requires_grad=False ) def apply( @@ -2039,22 +2044,51 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from sglang.srt.layers.quantization.fp8_utils import ( - cublas_mxfp8_blockscaled_linear, - ) - + from flashinfer import mm_mxfp8 + from flashinfer.fp8_quantization import mxfp8_quantize + + input_tensor = x[0] if isinstance(x, tuple) else x + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) # [M, K] + output_shape = [*input_tensor.shape[:-1], layer.weight.shape[1]] # [..., N] + m_actual = input_2d.shape[0] + k_dim = input_2d.shape[1] + + # Quantize activations to MXFP8 using swizzled scales for mm_mxfp8. + # If input is already pre-quantized, dequantize first when a 2D UE8M0 + # scale tensor is provided. if isinstance(x, tuple): - input_data, input_scale = x - else: - input_data, input_scale = x, None + input_scale = x[1] + if input_scale.dim() == 2: + input_2d = self._dequantize_mxfp8( + input_2d, + input_scale.view(-1, input_scale.shape[-1]), + ) - return cublas_mxfp8_blockscaled_linear( - input=input_data, - weight_t=layer.weight_t, - weight_scale_cublas=layer.weight_scale_cublas, - input_scale=input_scale, - bias=bias, + # CUTLASS MXFP8 requires M >= 32. Pad decode/capture micro-batches + # and slice the output back to original M. + if m_actual < 32: + input_padded = torch.zeros( + (32, k_dim), + dtype=input_2d.dtype, + device=input_2d.device, + ) + input_padded[:m_actual, :] = input_2d + input_2d = input_padded + + input_q, input_scale = mxfp8_quantize(input_2d, is_sf_swizzled_layout=True) + + out = mm_mxfp8( + input_q, + layer.weight, + input_scale.view(torch.uint8), + layer.weight_scale_inv, + out_dtype=torch.bfloat16, ) + out = out[:m_actual, :] + + if bias is not None: + out = out + bias + return out.view(*output_shape) class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): From 83088026c076523bd47e3f18fde791af90de1afe Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 16:20:11 +0000 Subject: [PATCH 10/14] WIPremove hack Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../srt/layers/quantization/modelopt_quant.py | 75 +++++++------------ 1 file changed, 25 insertions(+), 50 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index e58c049c11f1..e945b9954cd4 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2030,7 +2030,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Reconstruct bf16 values first, then requantize to flashinfer swizzled # scales so mm_mxfp8 can consume them directly. weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) - weight_q, weight_scale = mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=True) + weight_q, weight_scale = mxfp8_quantize( + weight_bf16, is_sf_swizzled_layout=True, backend="cuda" + ) # mm_mxfp8 expects B as [K, N] column-major. layer.weight = Parameter(weight_q.t(), requires_grad=False) @@ -2075,7 +2077,9 @@ def apply( input_padded[:m_actual, :] = input_2d input_2d = input_padded - input_q, input_scale = mxfp8_quantize(input_2d, is_sf_swizzled_layout=True) + input_q, input_scale = mxfp8_quantize( + input_2d, is_sf_swizzled_layout=True, backend="cuda" + ) out = mm_mxfp8( input_q, @@ -2094,9 +2098,8 @@ def apply( class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt MXFP8 quantization. - Loads FP8 expert weights with uint8 UE8M0 block scales. - For correctness, this path dequantizes expert weights to BF16 at load time - and runs MoE with the Triton unquantized kernel path. + Loads FP8 expert weights with uint8 UE8M0 block scales (group_size=32). + Converts UE8M0 scales to float32 and runs quantized Triton block-FP8 MoE. """ BLOCK_K = 32 # MXFP8 group size @@ -2180,50 +2183,23 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs_block) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Rename weight_scale → weight_scale_inv to keep parameter naming - # consistent with existing MoE code paths. - layer.w13_weight_scale_inv = Parameter( - layer.w13_weight_scale.data, requires_grad=False - ) - layer.w2_weight_scale_inv = Parameter( - layer.w2_weight_scale.data, requires_grad=False - ) + # Keep FP8 weights as-is; convert UE8M0 uint8 block scales to float32 + # so the Triton block-FP8 kernel can use them directly. layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) - layer.w13_weight_scale_inv.format_ue8m0 = True - layer.w2_weight_scale_inv.format_ue8m0 = True - - # Correctness fallback: convert checkpoint MXFP8 MoE weights to BF16. - # This avoids backend-specific MXFP8 MoE scale-layout mismatches. - def _dequantize_mxfp8_3d( - fp8_weight: torch.Tensor, scale_u8: torch.Tensor - ) -> torch.Tensor: - # fp8_weight: [E, N, K], scale_u8: [E, N, K//32] - group_size = 32 - num_experts, n_dim, k_dim = fp8_weight.shape - assert ( - k_dim % group_size == 0 - ), f"K={k_dim} must be divisible by {group_size} for MXFP8." - scales_f32 = torch.pow( - 2.0, scale_u8.to(dtype=torch.float32, device=fp8_weight.device) - 127.0 - ) - weight_f32 = fp8_weight.to(torch.float32).view( - num_experts, n_dim, k_dim // group_size, group_size - ) - return ( - (weight_f32 * scales_f32.unsqueeze(-1)) - .view(num_experts, n_dim, k_dim) - .to(torch.bfloat16) - ) - layer.w13_weight = Parameter( - _dequantize_mxfp8_3d( - layer.w13_weight.data, layer.w13_weight_scale_inv.data + layer.w13_weight_scale_inv = Parameter( + torch.pow( + 2.0, + layer.w13_weight_scale.data.to(torch.float32) - 127.0, ), requires_grad=False, ) - layer.w2_weight = Parameter( - _dequantize_mxfp8_3d(layer.w2_weight.data, layer.w2_weight_scale_inv.data), + layer.w2_weight_scale_inv = Parameter( + torch.pow( + 2.0, + layer.w2_weight_scale.data.to(torch.float32) - 127.0, + ), requires_grad=False, ) @@ -2231,12 +2207,6 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - requested_backend = get_moe_runner_backend() - if requested_backend.is_flashinfer_trtllm(): - logger.warning( - "modelopt_mxfp8 MoE falls back to Triton BF16 execution for " - "correctness; ignoring flashinfer_trtllm runner for this layer." - ) self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply( @@ -2244,11 +2214,16 @@ def apply( layer: torch.nn.Module, dispatch_output: "StandardDispatchOutput", ) -> "CombineInput": - # Run as unquantized Triton MoE with dequantized BF16 weights. + # Run quantized Triton MoE with FP8 weights and block scales. from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo quant_info = TritonMoeQuantInfo( w13_weight=layer.w13_weight, w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + per_channel_quant=False, + w13_scale=layer.w13_weight_scale_inv, + w2_scale=layer.w2_weight_scale_inv, + block_shape=[1, 32], ) return self.runner.run(dispatch_output, quant_info) From cb270eb6957fd007d9de7e4e4041506a5f8c3cb4 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 16:37:02 +0000 Subject: [PATCH 11/14] WIPmoe flashinfer cutlass Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../srt/layers/quantization/modelopt_quant.py | 171 ++++++++++++++++-- 1 file changed, 154 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index e945b9954cd4..00d270ae5773 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2099,7 +2099,8 @@ class ModelOptMxfp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt MXFP8 quantization. Loads FP8 expert weights with uint8 UE8M0 block scales (group_size=32). - Converts UE8M0 scales to float32 and runs quantized Triton block-FP8 MoE. + Supports CUTLASS (native MXFP8) and Triton (block-FP8 with float scales) + MoE runner backends. """ BLOCK_K = 32 # MXFP8 group size @@ -2182,39 +2183,175 @@ def create_weights( set_weight_attrs(w13_weight_scale, extra_weight_attrs_block) set_weight_attrs(w2_weight_scale, extra_weight_attrs_block) + @staticmethod + def _swizzle_mxfp8_scales( + weight_shape: tuple, scale: torch.Tensor + ) -> torch.Tensor: + """Swizzle uint8 UE8M0 scales for CUTLASS MXFP8 MoE kernel.""" + from triton_kernels.tensor import convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + + num_experts, m, k = weight_shape + aligned_m = ((m + 127) // 128) * 128 + scale = scale.view(num_experts, aligned_m, k // 32) + num_warps = 8 + scale_layout, scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps + ) + ) + scale = scale.transpose(-2, -1) + scale = convert_layout( + wrap_torch_tensor(scale), scale_layout, **scale_layout_opts + ) + scale = scale.data.view(num_experts, aligned_m, k // 32) + return scale + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Keep FP8 weights as-is; convert UE8M0 uint8 block scales to float32 - # so the Triton block-FP8 kernel can use them directly. layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) - layer.w13_weight_scale_inv = Parameter( - torch.pow( - 2.0, - layer.w13_weight_scale.data.to(torch.float32) - 127.0, - ), - requires_grad=False, + if get_moe_runner_backend().is_cutlass(): + # CUTLASS MXFP8: keep uint8 scales, swizzle for kernel layout + layer.w13_weight_scale_inv = Parameter( + self._swizzle_mxfp8_scales( + layer.w13_weight.data.shape, layer.w13_weight_scale.data + ), + requires_grad=False, + ) + layer.w2_weight_scale_inv = Parameter( + self._swizzle_mxfp8_scales( + layer.w2_weight.data.shape, layer.w2_weight_scale.data + ), + requires_grad=False, + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + else: + # Triton block-FP8: convert UE8M0 uint8 → float32 scales + layer.w13_weight_scale_inv = Parameter( + torch.pow( + 2.0, + layer.w13_weight_scale.data.to(torch.float32) - 127.0, + ), + requires_grad=False, + ) + layer.w2_weight_scale_inv = Parameter( + torch.pow( + 2.0, + layer.w2_weight_scale.data.to(torch.float32) - 127.0, + ), + requires_grad=False, + ) + + def _ensure_cutlass_buffers(self, layer: torch.nn.Module) -> None: + if getattr(self, "_cutlass_buffers_ready", False): + return + device = layer.w13_weight.device + num_experts = layer.w13_weight.shape[0] + hidden_size = layer.w2_weight.shape[1] + intermediate_size_per_partition = layer.w2_weight.shape[2] + + self.ab_strides1 = torch.full( + (num_experts,), hidden_size, device=device, dtype=torch.int64 ) - layer.w2_weight_scale_inv = Parameter( - torch.pow( - 2.0, - layer.w2_weight_scale.data.to(torch.float32) - 127.0, - ), - requires_grad=False, + self.c_strides1 = torch.full( + (num_experts,), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + self.ab_strides2 = torch.full( + (num_experts,), + intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + self.c_strides2 = torch.full( + (num_experts,), hidden_size, device=device, dtype=torch.int64 + ) + self.workspace = torch.empty(90000, device=device, dtype=torch.uint8) + self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) + self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) + self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) + self.a_scales_ptr = torch.empty( + num_experts, device=device, dtype=torch.int64 + ) + self.b_scales_ptr = torch.empty( + num_experts, device=device, dtype=torch.int64 + ) + self.expert_offsets = torch.empty( + num_experts + 1, device=device, dtype=torch.int32 ) + self.problem_sizes1 = torch.empty( + num_experts, 3, device=device, dtype=torch.int32 + ) + self.problem_sizes2 = torch.empty( + num_experts, 3, device=device, dtype=torch.int32 + ) + self._cutlass_buffers_ready = True def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + if get_moe_runner_backend().is_cutlass(): + # CUTLASS path is called directly in apply(), no MoeRunner needed. + self._ensure_cutlass_buffers(layer) + else: + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply( self, layer: torch.nn.Module, dispatch_output: "StandardDispatchOutput", ) -> "CombineInput": - # Run quantized Triton MoE with FP8 weights and block scales. + if get_moe_runner_backend().is_cutlass(): + from sglang.srt.distributed import get_tp_group + from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, + ) + from sglang.srt.layers.dp_attention import is_allocation_symmetric + from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_weights, topk_ids, _ = dispatch_output.topk_output + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty_like(x) + + output = cutlass_fused_experts_fp8( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale_inv.transpose(1, 2), + layer.w2_weight_scale_inv.transpose(1, 2), + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + self.workspace, + self.a_ptr, + self.b_ptr, + self.out_ptr, + self.a_scales_ptr, + self.b_scales_ptr, + self.expert_offsets, + self.problem_sizes1, + self.problem_sizes2, + use_fp8_blockscale=True, + use_mxfp8=True, + output=symm_output, + enable_es=(True, True), + ) + return StandardCombineInput(hidden_states=output) + + # Fallback: Triton block-FP8 with float scales from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo quant_info = TritonMoeQuantInfo( From d0368761bd655d3683c798e264c5f40788c33c8b Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Sun, 8 Feb 2026 16:44:57 +0000 Subject: [PATCH 12/14] mxfp8: remove non-working flashinfer_trtllm MoE backend The flashinfer_trtllm MXFP8 MoE path was untested and non-functional. Remove all related code: align_mxfp8_moe_weights_for_flashinfer_trtllm, use_mxfp8 field on FlashInferTrtllmFp8MoeQuantInfo, MXFP8 branches in fused_experts_none_to_flashinfer_trtllm_fp8, and auto-selection of flashinfer_trtllm for modelopt_mxfp8. Working backends are cutlass and triton. Co-Authored-By: Claude Opus 4.6 --- .../srt/layers/moe/fused_moe_triton/layer.py | 19 +-- .../moe/moe_runner/flashinfer_trtllm.py | 140 ++---------------- python/sglang/srt/layers/quantization/fp8.py | 10 -- python/sglang/srt/server_args.py | 6 +- 4 files changed, 13 insertions(+), 162 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index dddd2df3c564..2b26ae8f87f7 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -58,10 +58,7 @@ CompressedTensorsMxInt4MoEMethod, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod -from sglang.srt.layers.quantization.modelopt_quant import ( - ModelOptMxfp8MoEMethod, - ModelOptNvFp4FusedMoEMethod, -) +from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.server_args import get_global_server_args @@ -707,7 +704,6 @@ def _weight_loader_impl( # Flashinfer assumes w31 format for w13_weight. Same for the scales. if self.use_flashinfer_trtllm_moe and ( isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - or isinstance(self.quant_method, ModelOptMxfp8MoEMethod) or isinstance(self.quant_method, Fp8MoEMethod) or isinstance(self.quant_method, UnquantizedFusedMoEMethod) or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod) @@ -773,19 +769,6 @@ def _weight_loader_impl( if "ModelOpt" in self.quant_method.__class__.__name__: is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - is_mxfp8_variant = isinstance(self.quant_method, ModelOptMxfp8MoEMethod) - - if is_mxfp8_variant: - # MXFP8: weight_scale is block scale (uint8 UE8M0), not per-tensor - if "weight_scale" in weight_name or "weight" in weight_name: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank, - ) - return # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor per_tensor_conditions = ( diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 875a9dfabf63..3aa20b17a17a 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -181,104 +181,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: ) -def align_mxfp8_moe_weights_for_flashinfer_trtllm( - layer: Module, quantize: bool = True -) -> None: - """Prepare MXFP8 MoE weights/scales for FlashInfer TRT-LLM kernels. - - For the trtllm MXFP8 path, weights must be: - 1. Quantized to MXFP8 using flashinfer's mxfp8_quantize (swizzled scales) - 2. Rows reordered for gated activation (w13 only) - 3. Shuffled for transposed MMA output - - Args: - layer: The MoE layer to process. - quantize: If True, quantize BF16 weights to MXFP8. - If False, assume weights are already MXFP8 quantized. - """ - from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a - from flashinfer.fp8_quantization import mxfp8_quantize - - w13_weight = cast(torch.Tensor, layer.w13_weight) - w2_weight = cast(torch.Tensor, layer.w2_weight) - num_experts = w13_weight.shape[0] - - def _dequantize_mxfp8_2d(fp8_weight: torch.Tensor, scale_u8: torch.Tensor): - group_size = 32 - n, k = fp8_weight.shape - n_groups = k // group_size - scales_f32 = torch.pow( - 2.0, scale_u8.to(dtype=torch.float32, device=fp8_weight.device) - 127.0 - ) - weight_f32 = fp8_weight.to(torch.float32).view(n, n_groups, group_size) - return (weight_f32 * scales_f32.unsqueeze(-1)).view(n, k).to(torch.bfloat16) - - if not quantize: - # Pre-quantized FP8 checkpoint stores raw FP8 mantissas plus 2D UE8M0 - # block scales. Reconstruct bf16 first, then requantize to flashinfer's - # swizzled scale layout expected by the TRT-LLM kernel. - w13_scale = cast(torch.Tensor, layer.w13_weight_scale_inv) - w2_scale = cast(torch.Tensor, layer.w2_weight_scale_inv) - w13_weight = torch.stack( - [ - _dequantize_mxfp8_2d(w13_weight[i], w13_scale[i]) - for i in range(num_experts) - ] - ) - w2_weight = torch.stack( - [ - _dequantize_mxfp8_2d(w2_weight[i], w2_scale[i]) - for i in range(num_experts) - ] - ) - - # Quantize (or re-quantize) to MXFP8 with swizzled scales for weights - w13_q_list, w13_s_list = [], [] - w2_q_list, w2_s_list = [], [] - for i in range(num_experts): - w13_q, w13_s = mxfp8_quantize( - w13_weight[i], is_sf_swizzled_layout=True, backend="cuda" - ) - w13_q_list.append(w13_q) - w13_s_list.append(w13_s.view(torch.uint8)) - - w2_q, w2_s = mxfp8_quantize( - w2_weight[i], is_sf_swizzled_layout=True, backend="cuda" - ) - w2_q_list.append(w2_q) - w2_s_list.append(w2_s.view(torch.uint8)) - - w13_weight = torch.stack(w13_q_list) - w13_scales = torch.stack(w13_s_list) - w2_weight = torch.stack(w2_q_list) - w2_scales = torch.stack(w2_s_list) - - # Reorder rows for gated activation (w13 only — interleaves gate/up halves) - w13_interleaved = torch.stack( - [reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)] - ).reshape_as(w13_weight) - - # Shuffle weights for transposed MMA output (both w13 and w2) - epilogue_tile_m = 128 - w13_shuffled = torch.stack( - [ - shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) - for i in range(num_experts) - ] - ).view(torch.float8_e4m3fn) - w2_shuffled = torch.stack( - [ - shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m) - for i in range(num_experts) - ] - ).view(torch.float8_e4m3fn) - - layer.w13_weight = Parameter(w13_shuffled, requires_grad=False) - layer.w2_weight = Parameter(w2_shuffled, requires_grad=False) - layer.w13_weight_scale_inv = Parameter(w13_scales, requires_grad=False) - layer.w2_weight_scale_inv = Parameter(w2_scales, requires_grad=False) - - @dataclass class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): """Quantization payload consumed by FlashInfer TRT-LLM FP8 MoE kernels.""" @@ -308,9 +210,6 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): output2_scales_scalar: torch.Tensor | None = None use_routing_scales_on_input: bool = False - # MXFP8 path - use_mxfp8: bool = False - def fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output: StandardDispatchOutput, @@ -318,7 +217,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( runner_config: MoeRunnerConfig, ) -> StandardCombineInput: from flashinfer.fused_moe import ( - Fp8QuantizationType, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, ) @@ -345,28 +243,12 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( routing_method_type = quant_info.routing_method_type if quant_info.block_quant: + assert quant_info.weight_block_k is not None assert quant_info.w13_weight_scale_inv is not None assert quant_info.w2_weight_scale_inv is not None - if quant_info.use_mxfp8: - # MXFP8 path: quantize activations with flashinfer's mxfp8_quantize - from flashinfer.fp8_quantization import mxfp8_quantize - - a_q, a_sf = mxfp8_quantize( - hidden_states, is_sf_swizzled_layout=False, backend="cuda" - ) - a_sf = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1) - fp8_quant_type = Fp8QuantizationType.MxFp8 - use_shuffled_weight = True - else: - # DeepSeek FP8 block scale path - assert quant_info.weight_block_k is not None - a_q, a_sf = per_token_group_quant_fp8( - hidden_states, quant_info.weight_block_k - ) - a_sf = a_sf.t().contiguous() - fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 - use_shuffled_weight = False + a_q, a_sf = per_token_group_quant_fp8(hidden_states, quant_info.weight_block_k) + a_sf_t = a_sf.t().contiguous() with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() @@ -375,16 +257,15 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. # If the bug is fixed, we can only put the output tensor allocation under the context manager. - _rl = ( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ) output = trtllm_fp8_block_scale_moe( - routing_logits=_rl, + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), routing_bias=correction_bias, hidden_states=a_q, - hidden_states_scale=a_sf, + hidden_states_scale=a_sf_t, gemm1_weights=quant_info.w13_weight, gemm1_weights_scale=quant_info.w13_weight_scale_inv, gemm2_weights=quant_info.w2_weight, @@ -404,9 +285,8 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( else 1.0 ), routing_method_type=routing_method_type, - use_shuffled_weight=use_shuffled_weight, + use_shuffled_weight=False, tune_max_num_tokens=next_power_of_2(a_q.shape[0]), - fp8_quantization_type=fp8_quant_type, ) else: assert quant_info.w13_input_scale is not None diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 0d450afdf08a..cb3ca7e0d856 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -972,15 +972,6 @@ def _process_mxfp8_moe_weights(self, layer: Module, quantize: bool = True) -> No if not (_is_cuda and is_sm100_supported()): raise RuntimeError("MXFP8 MoE quantization requires SM100.") - # For trtllm backend, use flashinfer-native MXFP8 weight preparation - if get_moe_runner_backend().is_flashinfer_trtllm(): - from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( - align_mxfp8_moe_weights_for_flashinfer_trtllm, - ) - - align_mxfp8_moe_weights_for_flashinfer_trtllm(layer, quantize=quantize) - return - def _quantize_and_swizzle_with_cutlass_es_kernel(weight: torch.Tensor): from sgl_kernel import es_sm100_mxfp8_blockscaled_grouped_quant @@ -1508,7 +1499,6 @@ def apply( if not self.block_quant else None ), - use_mxfp8=self.use_mxfp8, ) elif self.runner.runner_backend.is_triton(): quant_info = TritonMoeQuantInfo( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b851371f9b02..47a6b05bc851 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1298,8 +1298,7 @@ def _handle_model_specific_adjustments(self): if ( self.moe_a2a_backend == "none" and self.moe_runner_backend == "auto" - and self.quantization - in ["fp8", "modelopt_fp8", "modelopt_mxfp8", "modelopt_fp4"] + and self.quantization in ["fp8", "modelopt_fp8", "modelopt_fp4"] ): self.moe_runner_backend = "flashinfer_trtllm" logger.info( @@ -1458,7 +1457,7 @@ def _handle_model_specific_adjustments(self): "intel_xpu", }, f"fa3, aiter, triton, trtllm_mha or intel_xpu is required for Llama4 model but got {self.attention_backend}" if is_sm100_supported() and self.moe_runner_backend == "auto": - if self.quantization in {"fp8", "modelopt_fp8", "modelopt_mxfp8"}: + if self.quantization in {"fp8", "modelopt_fp8"}: self.moe_runner_backend = "flashinfer_trtllm" logger.info( "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4" @@ -2080,7 +2079,6 @@ def _handle_moe_kernel_config(self): "modelopt_fp4", "fp8", "modelopt_fp8", - "modelopt_mxfp8", "compressed-tensors", None, ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'compressed-tensors', or bfloat16 (None)." From 6c41d920644bfbb9f3356bcd931a217ed6520c5a Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:06:26 +0000 Subject: [PATCH 13/14] WIP Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- .../srt/layers/moe/fused_moe_triton/layer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 2b26ae8f87f7..36b491985124 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -58,7 +58,10 @@ CompressedTensorsMxInt4MoEMethod, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod -from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod +from sglang.srt.layers.quantization.modelopt_quant import ( + ModelOptMxfp8MoEMethod, + ModelOptNvFp4FusedMoEMethod, +) from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.server_args import get_global_server_args @@ -769,6 +772,19 @@ def _weight_loader_impl( if "ModelOpt" in self.quant_method.__class__.__name__: is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) + is_mxfp8_variant = isinstance(self.quant_method, ModelOptMxfp8MoEMethod) + + if is_mxfp8_variant: + # MXFP8: weight_scale is block scale (uint8 UE8M0), not per-tensor + if "weight_scale" in weight_name or "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + return # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor per_tensor_conditions = ( From ec9ff1c80cdfae00f2fe1ee167f2b83d260b06ac Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:26:47 +0000 Subject: [PATCH 14/14] mxfp8: cleanup dead cuBLAS code, dedup dequantize, extract workspace constant - Remove unused cublas_mxfp8_blockscaled_linear, prepare_mxfp8_weight_for_cublas, and _interleave_mxfp8_scales_for_cublas from fp8_utils.py - Extract _dequantize_mxfp8 to shared dequantize_mxfp8() in fp8_utils.py - Add warning for unrecognized MXFP8 weight names in MoE weight loader - Extract magic 90000 workspace size to CUTLASS_MOE_WORKSPACE_BYTES in cutlass_moe.py --- python/sglang/srt/layers/moe/cutlass_moe.py | 4 + .../srt/layers/moe/fused_moe_triton/layer.py | 5 + python/sglang/srt/layers/quantization/fp8.py | 5 +- .../srt/layers/quantization/fp8_utils.py | 116 +++--------------- .../srt/layers/quantization/modelopt_quant.py | 36 ++---- 5 files changed, 42 insertions(+), 124 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 9fbc2764daec..0f4b417cae44 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -7,6 +7,10 @@ from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams from sglang.srt.utils import is_cuda, is_sm90_supported, is_sm100_supported +# Workspace size required by CUTLASS grouped GEMM kernels (bytes). +# Used by both FP8 and MXFP8 MoE paths. +CUTLASS_MOE_WORKSPACE_BYTES = 90000 + _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import ( diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 36b491985124..61f4e5f276ce 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -784,6 +784,11 @@ def _weight_loader_impl( expert_data=expert_data, tp_rank=tp_rank, ) + else: + logger.warning( + "MXFP8 MoE: ignoring unrecognized weight %s", + weight_name, + ) return # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index cb3ca7e0d856..43c5d8874335 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -21,6 +21,7 @@ ) from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.cutlass_moe import CUTLASS_MOE_WORKSPACE_BYTES from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( FlashInferTrtllmFp8MoeQuantInfo, @@ -1553,7 +1554,9 @@ def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None: self.c_strides2 = torch.full( (num_experts,), hidden_size, device=device, dtype=torch.int64 ) - self.workspace = torch.empty(90000, device=device, dtype=torch.uint8) + self.workspace = torch.empty( + CUTLASS_MOE_WORKSPACE_BYTES, device=device, dtype=torch.uint8 + ) self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 31b3db73b5ce..04f4c24cf8e4 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -654,109 +654,29 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor: return packed.view(1, scale_m, scale_k, 2, 256) -def _interleave_mxfp8_scales_for_cublas( - scale_u8: torch.Tensor, m_padded: int, k: int +def dequantize_mxfp8( + data: torch.Tensor, scale_u8: torch.Tensor, group_size: int = 32 ) -> torch.Tensor: - """Convert natural MXFP8 scales (M, K//32) to cuBLAS interleaved layout.""" - k_groups = k // 32 - return ( - scale_u8.view(m_padded // 128, 4, 32, k // 128, 4) - .permute(0, 3, 2, 1, 4) - .contiguous() - .view(m_padded, k_groups) - .view(torch.float8_e8m0fnu) - ) - - -def prepare_mxfp8_weight_for_cublas( - weight: torch.Tensor, weight_scale: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - """Pre-compute (K, N) weight view and interleaved scales for torch._scaled_mm.""" - n, k = weight.shape - assert n % 128 == 0, f"N={n} must be divisible by 128 for MXFP8" - assert k % 128 == 0, f"K={k} must be divisible by 128 for MXFP8" - # Keep non-contiguous transpose view for column-major B expected by cuBLAS. - weight_t = weight.t() - weight_scale_cublas = _interleave_mxfp8_scales_for_cublas(weight_scale, n, k) - return weight_t, weight_scale_cublas - - -def cublas_mxfp8_blockscaled_linear( - input: torch.Tensor, - weight_t: torch.Tensor, - weight_scale_cublas: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: - """MXFP8 dense linear via cuBLAS torch._scaled_mm.""" - if not (_is_cuda and is_sm100_supported()): - raise RuntimeError("MXFP8 cuBLAS linear requires Blackwell GPUs (SM100+).") + """Dequantize MXFP8 tensor with UE8M0 scales back to bf16. - if not hasattr(cublas_mxfp8_blockscaled_linear, "_logged"): - logger.info("Using cuBLAS MXFP8 dense linear (torch._scaled_mm)") - cublas_mxfp8_blockscaled_linear._logged = True + Applies per-group scaling: fp8_val * 2^(scale - 127). - input_2d = input.view(-1, input.shape[-1]).contiguous() - k, n = weight_t.shape - output_shape = [*input.shape[:-1], n] - m = input_2d.shape[0] - assert ( - input_2d.shape[1] == k - ), f"K mismatch: input {input_2d.shape[1]} vs weight {k}" - - if input_scale is None: - q_input, x_scale_u8 = mxfp8_group_quantize(input_2d) - else: - q_input = input_2d - x_scale_u8 = input_scale - assert x_scale_u8.dtype == torch.uint8, "MXFP8 input_scale must be UE8M0 uint8." - assert x_scale_u8.shape == (m, k // 32) - - if output_dtype is None: - if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32): - output_dtype = input_2d.dtype - else: - output_dtype = torch.bfloat16 - - if m % 128 != 0: - m_padded = ceil_div(m, 128) * 128 - pad_rows = m_padded - m - q_input = torch.cat( - [ - q_input, - torch.zeros((pad_rows, k), device=q_input.device, dtype=q_input.dtype), - ], - dim=0, - ) - x_scale_u8 = torch.cat( - [ - x_scale_u8, - torch.full( - (pad_rows, k // 32), - 127, - device=x_scale_u8.device, - dtype=x_scale_u8.dtype, - ), - ], - dim=0, - ) - else: - m_padded = m + Args: + data: FP8 tensor of shape (M, K). + scale_u8: uint8 UE8M0 scales of shape (M, K // group_size). + group_size: Number of elements per scale group (default 32). - sa = _interleave_mxfp8_scales_for_cublas(x_scale_u8, m_padded, k) - output = torch._scaled_mm( - q_input, - weight_t, - scale_a=sa, - scale_b=weight_scale_cublas, - out_dtype=output_dtype, + Returns: + bf16 tensor of shape (M, K). + """ + m, k = data.shape + n_groups = k // group_size + scales_f32 = torch.pow( + 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 ) - - output = output[:m, :] - if bias is not None: - output += bias - return output.view(*output_shape) + data_f32 = data.to(torch.float32).view(m, n_groups, group_size) + scales_f32 = scales_f32.view(m, n_groups, 1) + return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) def triton_mxfp8_blockscaled_linear( diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 00d270ae5773..7bd473e4626e 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -36,6 +36,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, cutlass_fp8_supported, + dequantize_mxfp8, is_blackwell_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod @@ -1964,19 +1965,6 @@ def __init__(self, quant_config: ModelOptMxfp8Config): super().__init__() self.quant_config = quant_config - @staticmethod - def _dequantize_mxfp8(data: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor: - """Dequantize MXFP8 tensor with UE8M0 scales back to bf16.""" - group_size = 32 - m, k = data.shape - n_groups = k // group_size - scales_f32 = torch.pow( - 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 - ) - data_f32 = data.to(torch.float32).view(m, n_groups, group_size) - scales_f32 = scales_f32.view(m, n_groups, 1) - return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) - def create_weights( self, layer: torch.nn.Module, @@ -2029,7 +2017,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Checkpoint MXFP8 stores raw FP8 data and per-block UE8M0 scales. # Reconstruct bf16 values first, then requantize to flashinfer swizzled # scales so mm_mxfp8 can consume them directly. - weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) + weight_bf16 = dequantize_mxfp8(layer.weight.data, layer.weight_scale.data) weight_q, weight_scale = mxfp8_quantize( weight_bf16, is_sf_swizzled_layout=True, backend="cuda" ) @@ -2061,7 +2049,7 @@ def apply( if isinstance(x, tuple): input_scale = x[1] if input_scale.dim() == 2: - input_2d = self._dequantize_mxfp8( + input_2d = dequantize_mxfp8( input_2d, input_scale.view(-1, input_scale.shape[-1]), ) @@ -2184,9 +2172,7 @@ def create_weights( set_weight_attrs(w2_weight_scale, extra_weight_attrs_block) @staticmethod - def _swizzle_mxfp8_scales( - weight_shape: tuple, scale: torch.Tensor - ) -> torch.Tensor: + def _swizzle_mxfp8_scales(weight_shape: tuple, scale: torch.Tensor) -> torch.Tensor: """Swizzle uint8 UE8M0 scales for CUTLASS MXFP8 MoE kernel.""" from triton_kernels.tensor import convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout @@ -2270,16 +2256,16 @@ def _ensure_cutlass_buffers(self, layer: torch.nn.Module) -> None: self.c_strides2 = torch.full( (num_experts,), hidden_size, device=device, dtype=torch.int64 ) - self.workspace = torch.empty(90000, device=device, dtype=torch.uint8) + from sglang.srt.layers.moe.cutlass_moe import CUTLASS_MOE_WORKSPACE_BYTES + + self.workspace = torch.empty( + CUTLASS_MOE_WORKSPACE_BYTES, device=device, dtype=torch.uint8 + ) self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) - self.a_scales_ptr = torch.empty( - num_experts, device=device, dtype=torch.int64 - ) - self.b_scales_ptr = torch.empty( - num_experts, device=device, dtype=torch.int64 - ) + self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) + self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64) self.expert_offsets = torch.empty( num_experts + 1, device=device, dtype=torch.int32 )