From 60e3d1896cb92a907a9abdcbc3cdcd093ef73e48 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 20 Feb 2026 20:32:47 +0000 Subject: [PATCH 1/7] enable mxfp8 on sm120 --- python/sglang/srt/configs/model_config.py | 5 +++ .../srt/layers/quantization/auto_round.py | 18 +++++++++ .../compressed_tensors/compressed_tensors.py | 28 ++++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 38 +++++++++++++++---- .../srt/layers/quantization/fp8_kernel.py | 20 +++++++++- .../srt/layers/quantization/fp8_utils.py | 15 +++++++- python/sglang/test/test_block_fp8.py | 6 +-- 7 files changed, 116 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6fbd1db8235d..2684d303ca9c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -877,6 +877,11 @@ def _verify_quantization(self) -> None: "petit_nvfp4": ["modelopt"], "w8a8_int8": ["compressed-tensors", "compressed_tensors"], "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], + "mxfp8": [ + "auto-round", # Allow MXFP8 with auto-round checkpoints that have data_type="mx_fp" + "compressed-tensors", # Allow MXFP8 with compressed-tensors checkpoints that have block_structure=[1, 32] + "compressed_tensors", + ], } if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/python/sglang/srt/layers/quantization/auto_round.py b/python/sglang/srt/layers/quantization/auto_round.py index 74c4f0231ead..ecd34d32424c 100644 --- a/python/sglang/srt/layers/quantization/auto_round.py +++ b/python/sglang/srt/layers/quantization/auto_round.py @@ -95,6 +95,24 @@ def get_min_capability(cls) -> int: def get_config_filenames(cls) -> list[str]: return ["quantization_config.json"] + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + """ + Detect MXFP8 format in AutoRound checkpoints. + AutoRound can produce MXFP8 format checkpoints (data_type="mx_fp") + even though quant_method says "auto-round". + """ + if hf_quant_cfg is None: + return None + + # Check if this is actually MXFP8 format + data_type = hf_quant_cfg.get("data_type", "").lower() + if data_type in ("mx_fp", "mxfp", "mxfp8"): + # This is MXFP8 format, not AutoRound INT format + return "mxfp8" + + return None + @classmethod def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": return cls( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 4cbfed6f90e7..ae1456d20707 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -324,6 +324,34 @@ def _quantization_scheme_map_from_config( def get_config_filenames(cls) -> List[str]: return [] + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + """ + Detect MXFP8 format in compressed-tensors checkpoints. + Compressed-tensors can store MXFP8 format checkpoints with block_structure=[1, 32] + and U8 scales, even though quant_method says "compressed-tensors". + """ + if hf_quant_cfg is None: + return None + + # Check if this is actually MXFP8 format by looking at block_structure + # MXFP8 uses block_structure=[1, 32] and U8 scales + try: + config_groups = hf_quant_cfg.get("config_groups", {}) + for group_name, group_config in config_groups.items(): + weights_config = group_config.get("weights", {}) + block_structure = weights_config.get("block_structure") + + # MXFP8 uses block_structure=[1, 32] + if block_structure == [1, 32]: + # Also check if scales are U8 (uint8) format + # This is a strong indicator of MXFP8 + return "mxfp8" + except (KeyError, TypeError, AttributeError): + pass + + return None + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: capability_tuple = DeviceCapability(*torch.cuda.get_device_capability()) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6f6ec68d8eb3..305df58bc955 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -77,6 +77,7 @@ is_npu, is_sm90_supported, is_sm100_supported, + is_sm120_supported, log_info_on_rank0, print_warning_once, set_weight_attrs, @@ -163,10 +164,30 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> Fp8Config: - quant_method = cls.get_from_keys(config, ["quant_method"]) + # Handle both flat and nested config formats + quant_method = "" + data_type = "" + + # Try flat format first + quant_method = cls.get_from_keys_or(config, ["quant_method"], "").lower() + data_type = cls.get_from_keys_or(config, ["data_type"], "").lower() + + # Fall back to nested format if not found + if not quant_method and not data_type: + try: + quantization_section = cls.get_from_keys(config, ["quantization"]) + quant_method = quantization_section.get("quant_method", "").lower() + data_type = quantization_section.get("data_type", "").lower() + except ValueError: + pass # No nested structure, use flat format + use_mxfp8 = "mxfp8" in quant_method + # Also detect MXFP8 from data_type (for AutoRound-converted MXFP8 checkpoints) + if not use_mxfp8 and data_type in ("mx_fp", "mxfp", "mxfp8"): + use_mxfp8 = True is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 - activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + # activation_scheme is optional for AutoRound-converted MXFP8 checkpoints + activation_scheme = cls.get_from_keys_or(config, ["activation_scheme"], "dynamic") ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None ) @@ -177,10 +198,11 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: layer.replace("model.", "") for layer in ignored_layers ] weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) - if use_mxfp8 and weight_block_size is not None: - logger.warning( - "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." - ) + if use_mxfp8: + if weight_block_size is not None: + logger.warning( + "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." + ) weight_block_size = [1, 32] return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, @@ -683,7 +705,9 @@ def __init__(self, quant_config: Fp8Config): cutlass_fp8_supported() ), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89" assert self.block_quant, "cutlass_fp8 MoE requires block quantization" - assert is_sm100_supported() or is_sm90_supported() + assert ( + is_sm100_supported() or is_sm90_supported() or is_sm120_supported() + ), "cutlass_fp8 MoE requires SM90, SM100, or SM120 GPUs" @staticmethod def is_deepgemm_moe_runner_backend_enabled() -> bool: diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 4a143d724f95..6526e98414dd 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -37,6 +37,8 @@ is_cpu, is_cuda, is_hip, + is_sm100_supported, + is_sm120_supported, log_info_on_rank0, ) from sglang.srt.utils.custom_op import register_custom_op @@ -1286,9 +1288,23 @@ def mxfp8_block_scaled_matmul_triton( block_m: int = 128, block_n: int = 256, block_k: int = 128, - num_stages: int = 4, + num_stages: Optional[int] = None, ) -> torch.Tensor: - """Block-scaled matmul for MXFP8 using Triton dot_scaled.""" + """Block-scaled matmul for MXFP8 using Triton dot_scaled. + + Args: + num_stages: Number of pipeline stages. If None, automatically selects based on GPU: + - SM120 (RTX 5070 Ti): 1 (to avoid shared memory limits) + - SM100 (B200/H200): 4 (better performance with more shared memory) + """ + # Auto-select num_stages based on GPU architecture if not provided + if num_stages is None: + if is_sm120_supported(): + num_stages = 1 # SM120 has less shared memory + elif is_sm100_supported(): + num_stages = 4 # SM100 can handle more stages + else: + num_stages = 1 # Default to conservative value M, K = a.shape N, K_b = b.shape assert K == K_b diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3b9616e2798f..68a9fb8debad 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -41,6 +41,7 @@ is_hip, is_sm90_supported, is_sm100_supported, + is_sm120_supported, offloader, ) @@ -662,8 +663,8 @@ def triton_mxfp8_blockscaled_linear( bias: Optional[torch.Tensor] = None, output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if not (_is_cuda and is_sm100_supported()): - raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100+).") + if not (_is_cuda and (is_sm100_supported() or is_sm120_supported())): + raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100/SM120).") input_2d = input.view(-1, input.shape[-1]).contiguous() output_shape = [*input.shape[:-1], weight.shape[0]] @@ -714,6 +715,15 @@ def triton_mxfp8_blockscaled_linear( a_scale_packed = _pack_mxfp8_scales(x_scale_u8) b_scale_packed = _pack_mxfp8_scales(weight_scale) + # Auto-select num_stages based on GPU architecture + # SM120 needs fewer stages due to shared memory constraints + if is_sm120_supported(): + num_stages = 1 + elif is_sm100_supported(): + num_stages = 4 + else: + num_stages = 1 # Default to conservative value + output = mxfp8_block_scaled_matmul_triton( q_input, a_scale_packed, @@ -723,6 +733,7 @@ def triton_mxfp8_blockscaled_linear( block_m=block_m, block_n=block_n, block_k=block_k, + num_stages=num_stages, ) output = output[:m, :] if bias is not None: diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index edfd98e42a9c..6b73a736acb3 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -19,7 +19,7 @@ mxfp8_group_quantize, triton_mxfp8_blockscaled_linear, ) -from sglang.srt.utils import is_sm100_supported +from sglang.srt.utils import is_sm100_supported, is_sm120_supported from sglang.test.test_utils import CustomTestCase _is_cuda = torch.cuda.is_available() and torch.version.cuda @@ -452,8 +452,8 @@ class TestMXFP8DenseLinear(CustomTestCase): def setUpClass(cls): if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA is not available") - if not is_sm100_supported(): - raise unittest.SkipTest("MXFP8 requires Blackwell (SM100+)") + if not (is_sm100_supported() or is_sm120_supported()): + raise unittest.SkipTest("MXFP8 requires Blackwell (SM100/SM120)") torch.set_default_device("cuda") def _mxfp8_dense_linear(self, M, NK, dtype, seed): From 4f93829cc1e8f0ba286c1650931af0360b160a9b Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 20 Feb 2026 21:48:30 +0000 Subject: [PATCH 2/7] cleanup, test on qwen2.5 0.5b instruct --- python/sglang/srt/configs/model_config.py | 6 +-- .../srt/layers/quantization/auto_round.py | 10 +--- .../compressed_tensors/compressed_tensors.py | 20 ++----- python/sglang/srt/layers/quantization/fp8.py | 52 ++++++++++++++----- .../srt/layers/quantization/fp8_kernel.py | 15 ++---- .../srt/layers/quantization/fp8_utils.py | 10 +--- 6 files changed, 49 insertions(+), 64 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 2684d303ca9c..303cb825416a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -877,11 +877,7 @@ def _verify_quantization(self) -> None: "petit_nvfp4": ["modelopt"], "w8a8_int8": ["compressed-tensors", "compressed_tensors"], "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], - "mxfp8": [ - "auto-round", # Allow MXFP8 with auto-round checkpoints that have data_type="mx_fp" - "compressed-tensors", # Allow MXFP8 with compressed-tensors checkpoints that have block_structure=[1, 32] - "compressed_tensors", - ], + "mxfp8": ["auto-round", "compressed-tensors", "compressed_tensors"], } if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/python/sglang/srt/layers/quantization/auto_round.py b/python/sglang/srt/layers/quantization/auto_round.py index ecd34d32424c..dd74b7af4c74 100644 --- a/python/sglang/srt/layers/quantization/auto_round.py +++ b/python/sglang/srt/layers/quantization/auto_round.py @@ -97,20 +97,12 @@ def get_config_filenames(cls) -> list[str]: @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - """ - Detect MXFP8 format in AutoRound checkpoints. - AutoRound can produce MXFP8 format checkpoints (data_type="mx_fp") - even though quant_method says "auto-round". - """ + """Detect MXFP8 format in AutoRound checkpoints.""" if hf_quant_cfg is None: return None - - # Check if this is actually MXFP8 format data_type = hf_quant_cfg.get("data_type", "").lower() if data_type in ("mx_fp", "mxfp", "mxfp8"): - # This is MXFP8 format, not AutoRound INT format return "mxfp8" - return None @classmethod diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index ae1456d20707..b833f9180d98 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -326,30 +326,16 @@ def get_config_filenames(cls) -> List[str]: @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - """ - Detect MXFP8 format in compressed-tensors checkpoints. - Compressed-tensors can store MXFP8 format checkpoints with block_structure=[1, 32] - and U8 scales, even though quant_method says "compressed-tensors". - """ + """Detect MXFP8 format in compressed-tensors checkpoints.""" if hf_quant_cfg is None: return None - - # Check if this is actually MXFP8 format by looking at block_structure - # MXFP8 uses block_structure=[1, 32] and U8 scales try: - config_groups = hf_quant_cfg.get("config_groups", {}) - for group_name, group_config in config_groups.items(): + for group_config in hf_quant_cfg.get("config_groups", {}).values(): weights_config = group_config.get("weights", {}) - block_structure = weights_config.get("block_structure") - - # MXFP8 uses block_structure=[1, 32] - if block_structure == [1, 32]: - # Also check if scales are U8 (uint8) format - # This is a strong indicator of MXFP8 + if weights_config.get("block_structure") == [1, 32]: return "mxfp8" except (KeyError, TypeError, AttributeError): pass - return None def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 305df58bc955..88d303405f30 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -167,11 +167,11 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: # Handle both flat and nested config formats quant_method = "" data_type = "" - + # Try flat format first quant_method = cls.get_from_keys_or(config, ["quant_method"], "").lower() data_type = cls.get_from_keys_or(config, ["data_type"], "").lower() - + # Fall back to nested format if not found if not quant_method and not data_type: try: @@ -180,14 +180,25 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: data_type = quantization_section.get("data_type", "").lower() except ValueError: pass # No nested structure, use flat format - + use_mxfp8 = "mxfp8" in quant_method - # Also detect MXFP8 from data_type (for AutoRound-converted MXFP8 checkpoints) if not use_mxfp8 and data_type in ("mx_fp", "mxfp", "mxfp8"): use_mxfp8 = True + if not use_mxfp8: + try: + for group_config in config.get("config_groups", {}).values(): + if group_config.get("weights", {}).get("block_structure") == [ + 1, + 32, + ]: + use_mxfp8 = True + break + except (KeyError, TypeError, AttributeError): + pass is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 - # activation_scheme is optional for AutoRound-converted MXFP8 checkpoints - activation_scheme = cls.get_from_keys_or(config, ["activation_scheme"], "dynamic") + activation_scheme = cls.get_from_keys_or( + config, ["activation_scheme"], "dynamic" + ) ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None ) @@ -200,9 +211,7 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if use_mxfp8: if weight_block_size is not None: - logger.warning( - "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." - ) + logger.warning("MXFP8 ignoring weight_block_size; fixed to [1, 32]") weight_block_size = [1, 32] return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, @@ -369,10 +378,6 @@ def create_weights( assert self.quant_config.activation_scheme == "dynamic" elif hasattr(self.quant_config, "linear_activation_scheme"): assert self.quant_config.linear_activation_scheme == "dynamic" - if self.use_mxfp8 and not self.is_checkpoint_fp8_serialized: - raise ValueError( - "MXFP8 requires fp8-serialized checkpoint for linear layers." - ) scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32 scale_init = torch.zeros if scale_dtype == torch.uint8 else torch.empty scale = BlockQuantScaleParameter( @@ -414,6 +419,27 @@ def create_weights( layer.register_parameter("input_scale", scale) else: layer.register_parameter("input_scale", None) + elif self.use_mxfp8 and self.block_quant: + # For online MXFP8 quantization from BF16/FP16 checkpoints, + # create weight_scale_inv parameter to be populated during quantization + if hasattr(self.quant_config, "activation_scheme"): + assert self.quant_config.activation_scheme == "dynamic" + elif hasattr(self.quant_config, "linear_activation_scheme"): + assert self.quant_config.linear_activation_scheme == "dynamic" + block_n, block_k = self.quant_config.weight_block_size + scale = BlockQuantScaleParameter( + data=torch.zeros( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale_inv", scale) + layer.register_parameter("input_scale", None) def process_weights_after_loading_block_quant(self, layer: Module) -> None: # If ROCm, normalize the weights and scales to e4m3fnuz diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 6526e98414dd..1466bac6bec4 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -1291,20 +1291,13 @@ def mxfp8_block_scaled_matmul_triton( num_stages: Optional[int] = None, ) -> torch.Tensor: """Block-scaled matmul for MXFP8 using Triton dot_scaled. - + Args: - num_stages: Number of pipeline stages. If None, automatically selects based on GPU: - - SM120 (RTX 5070 Ti): 1 (to avoid shared memory limits) - - SM100 (B200/H200): 4 (better performance with more shared memory) + num_stages: Number of pipeline stages. If None, auto-selects based on GPU: + SM120: 1, SM100: 4. """ - # Auto-select num_stages based on GPU architecture if not provided if num_stages is None: - if is_sm120_supported(): - num_stages = 1 # SM120 has less shared memory - elif is_sm100_supported(): - num_stages = 4 # SM100 can handle more stages - else: - num_stages = 1 # Default to conservative value + num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) M, K = a.shape N, K_b = b.shape assert K == K_b diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 68a9fb8debad..eb841e1edf35 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -715,15 +715,7 @@ def triton_mxfp8_blockscaled_linear( a_scale_packed = _pack_mxfp8_scales(x_scale_u8) b_scale_packed = _pack_mxfp8_scales(weight_scale) - # Auto-select num_stages based on GPU architecture - # SM120 needs fewer stages due to shared memory constraints - if is_sm120_supported(): - num_stages = 1 - elif is_sm100_supported(): - num_stages = 4 - else: - num_stages = 1 # Default to conservative value - + num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) output = mxfp8_block_scaled_matmul_triton( q_input, a_scale_packed, From 2e60fc7eed1d0bde7aa34b4eb77a5a9d2d15a155 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 20 Feb 2026 23:33:17 +0000 Subject: [PATCH 3/7] add torch empty cache to avoid oom problem --- python/sglang/srt/layers/quantization/fp8.py | 2 +- python/sglang/test/test_block_fp8.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 88d303405f30..6c47f312b47e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -439,7 +439,7 @@ def create_weights( ) scale.format_ue8m0 = True layer.register_parameter("weight_scale_inv", scale) - layer.register_parameter("input_scale", None) + layer.register_parameter("input_scale", None) def process_weights_after_loading_block_quant(self, layer: Module) -> None: # If ROCm, normalize the weights and scales to e4m3fnuz diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 6b73a736acb3..f7ef1cec94f6 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -457,6 +457,7 @@ def setUpClass(cls): torch.set_default_device("cuda") def _mxfp8_dense_linear(self, M, NK, dtype, seed): + torch.cuda.empty_cache() N, K = NK torch.manual_seed(seed) From 9e2fc8c15d546e086b30d671677169b495b76220 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 21 Feb 2026 00:13:06 +0000 Subject: [PATCH 4/7] revert --- python/sglang/srt/layers/quantization/fp8.py | 68 +++----------------- 1 file changed, 10 insertions(+), 58 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6c47f312b47e..c0295b60271d 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -164,41 +164,10 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> Fp8Config: - # Handle both flat and nested config formats - quant_method = "" - data_type = "" - - # Try flat format first - quant_method = cls.get_from_keys_or(config, ["quant_method"], "").lower() - data_type = cls.get_from_keys_or(config, ["data_type"], "").lower() - - # Fall back to nested format if not found - if not quant_method and not data_type: - try: - quantization_section = cls.get_from_keys(config, ["quantization"]) - quant_method = quantization_section.get("quant_method", "").lower() - data_type = quantization_section.get("data_type", "").lower() - except ValueError: - pass # No nested structure, use flat format - + quant_method = cls.get_from_keys(config, ["quant_method"]) use_mxfp8 = "mxfp8" in quant_method - if not use_mxfp8 and data_type in ("mx_fp", "mxfp", "mxfp8"): - use_mxfp8 = True - if not use_mxfp8: - try: - for group_config in config.get("config_groups", {}).values(): - if group_config.get("weights", {}).get("block_structure") == [ - 1, - 32, - ]: - use_mxfp8 = True - break - except (KeyError, TypeError, AttributeError): - pass is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 - activation_scheme = cls.get_from_keys_or( - config, ["activation_scheme"], "dynamic" - ) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None ) @@ -209,10 +178,10 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: layer.replace("model.", "") for layer in ignored_layers ] weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) - if use_mxfp8: - if weight_block_size is not None: - logger.warning("MXFP8 ignoring weight_block_size; fixed to [1, 32]") - weight_block_size = [1, 32] + if use_mxfp8 and weight_block_size is not None: + logger.warning( + "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." + ) return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, @@ -378,6 +347,10 @@ def create_weights( assert self.quant_config.activation_scheme == "dynamic" elif hasattr(self.quant_config, "linear_activation_scheme"): assert self.quant_config.linear_activation_scheme == "dynamic" + if self.use_mxfp8 and not self.is_checkpoint_fp8_serialized: + raise ValueError( + "MXFP8 requires fp8-serialized checkpoint for linear layers." + ) scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32 scale_init = torch.zeros if scale_dtype == torch.uint8 else torch.empty scale = BlockQuantScaleParameter( @@ -419,27 +392,6 @@ def create_weights( layer.register_parameter("input_scale", scale) else: layer.register_parameter("input_scale", None) - elif self.use_mxfp8 and self.block_quant: - # For online MXFP8 quantization from BF16/FP16 checkpoints, - # create weight_scale_inv parameter to be populated during quantization - if hasattr(self.quant_config, "activation_scheme"): - assert self.quant_config.activation_scheme == "dynamic" - elif hasattr(self.quant_config, "linear_activation_scheme"): - assert self.quant_config.linear_activation_scheme == "dynamic" - block_n, block_k = self.quant_config.weight_block_size - scale = BlockQuantScaleParameter( - data=torch.zeros( - (output_size_per_partition + block_n - 1) // block_n, - (input_size_per_partition + block_k - 1) // block_k, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale.format_ue8m0 = True - layer.register_parameter("weight_scale_inv", scale) - layer.register_parameter("input_scale", None) def process_weights_after_loading_block_quant(self, layer: Module) -> None: # If ROCm, normalize the weights and scales to e4m3fnuz From a57233b4e3c9b1a95a8780103edf7d74cef4c990 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 21 Feb 2026 00:14:54 +0000 Subject: [PATCH 5/7] revert --- python/sglang/srt/configs/model_config.py | 1 - .../sglang/srt/layers/quantization/auto_round.py | 10 ---------- .../compressed_tensors/compressed_tensors.py | 14 -------------- 3 files changed, 25 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 303cb825416a..6fbd1db8235d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -877,7 +877,6 @@ def _verify_quantization(self) -> None: "petit_nvfp4": ["modelopt"], "w8a8_int8": ["compressed-tensors", "compressed_tensors"], "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], - "mxfp8": ["auto-round", "compressed-tensors", "compressed_tensors"], } if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/python/sglang/srt/layers/quantization/auto_round.py b/python/sglang/srt/layers/quantization/auto_round.py index dd74b7af4c74..74c4f0231ead 100644 --- a/python/sglang/srt/layers/quantization/auto_round.py +++ b/python/sglang/srt/layers/quantization/auto_round.py @@ -95,16 +95,6 @@ def get_min_capability(cls) -> int: def get_config_filenames(cls) -> list[str]: return ["quantization_config.json"] - @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - """Detect MXFP8 format in AutoRound checkpoints.""" - if hf_quant_cfg is None: - return None - data_type = hf_quant_cfg.get("data_type", "").lower() - if data_type in ("mx_fp", "mxfp", "mxfp8"): - return "mxfp8" - return None - @classmethod def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": return cls( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index b833f9180d98..4cbfed6f90e7 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -324,20 +324,6 @@ def _quantization_scheme_map_from_config( def get_config_filenames(cls) -> List[str]: return [] - @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - """Detect MXFP8 format in compressed-tensors checkpoints.""" - if hf_quant_cfg is None: - return None - try: - for group_config in hf_quant_cfg.get("config_groups", {}).values(): - weights_config = group_config.get("weights", {}) - if weights_config.get("block_structure") == [1, 32]: - return "mxfp8" - except (KeyError, TypeError, AttributeError): - pass - return None - def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: capability_tuple = DeviceCapability(*torch.cuda.get_device_capability()) From d05d9ad2bb1723d1d90d5dfab3903cbfc951c5d0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 21 Feb 2026 00:15:32 +0000 Subject: [PATCH 6/7] revert --- python/sglang/srt/layers/quantization/fp8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index c0295b60271d..071550154f13 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -182,6 +182,7 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: logger.warning( "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." ) + weight_block_size = [1, 32] return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, From e37fc8a966d216572f8e18d7b025fe8f838db306 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 21 Feb 2026 00:18:06 +0000 Subject: [PATCH 7/7] revert --- python/sglang/test/test_block_fp8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index f7ef1cec94f6..6b73a736acb3 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -457,7 +457,6 @@ def setUpClass(cls): torch.set_default_device("cuda") def _mxfp8_dense_linear(self, M, NK, dtype, seed): - torch.cuda.empty_cache() N, K = NK torch.manual_seed(seed)