diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index c9dc1292d09e..e2fd5645f6a0 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -59,7 +59,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod] -- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] +- [`GptOssMxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.GptOssMxfp4MoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] ## Fused Experts Kernels diff --git a/vllm/config/model.py b/vllm/config/model.py index 1cce7f9d94cc..2b767b21a7c7 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -951,6 +951,7 @@ def _verify_quantization(self) -> None: # Ensure heavy backends are probed last to avoid unnecessary # imports during override detection (e.g., MXFP4 imports Triton) "mxfp4", + "gpt_oss_mxfp4", "cpu_awq", "gguf", ] @@ -966,7 +967,7 @@ def _verify_quantization(self) -> None: for name in quantization_methods: method = me_quant.get_quantization_config(name) quantization_override = method.override_quantization_method( - quant_cfg, self.quantization + quant_cfg, self.quantization, hf_config=self.hf_config ) if quantization_override is not None: # Raise error if the override is not custom (custom would diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c4fc1fd2557e..422f9e427620 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1063,7 +1063,7 @@ def weight_loader( expert_id: int, return_success: bool = False, ) -> bool | None: - if self.quant_config and self.quant_config.get_name() == "mxfp4": + if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: dim1 = loaded_weight.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index d4a0817e0be0..e7cfb881c1d9 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -194,7 +194,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: return None -def select_mxfp4_moe_backend( +def select_gpt_oss_mxfp4_moe_backend( config: FusedMoEConfig, ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: """ @@ -400,7 +400,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size( return hidden_size, intermediate_size -def convert_to_mxfp4_moe_kernel_format( +def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( mxfp4_backend: Mxfp4MoeBackend, layer: torch.nn.Module, w13_weight: torch.Tensor, @@ -426,7 +426,10 @@ def convert_to_mxfp4_moe_kernel_format( sf_block_size = 32 # mxfp4 block size - if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + if mxfp4_backend in ( + Mxfp4MoeBackend.MARLIN, + Mxfp4MoeBackend.BATCHED_MARLIN, + ): 
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_mxfp4_layer_for_marlin, ) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1ac0f9ee9cc5..6313db78a82b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -30,6 +30,7 @@ "torchao", "inc", "mxfp4", + "gpt_oss_mxfp4", "mxfp8", "cpu_awq", "online", @@ -133,7 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ModelOptNvFp4Config, ) from .moe_wna16 import MoeWNA16Config - from .mxfp4 import Mxfp4Config + from .mxfp4 import GptOssMxfp4Config, Mxfp4Config from .mxfp8 import Mxfp8Config from .online.base import OnlineQuantizationConfig from .torchao import TorchAOConfig @@ -160,6 +161,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "auto-round": INCConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "gpt_oss_mxfp4": GptOssMxfp4Config, "mxfp8": Mxfp8Config, "cpu_awq": CPUAWQConfig, "online": OnlineQuantizationConfig, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index be3001a7fa1d..cfad1f86faa2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -232,7 +232,7 @@ def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> "QuantizationMethods | None": # Skip override to marlin kernels, as they are not # batch invariant diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index eedc62f7d4d5..344ddd8abd25 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -110,13 +110,22 @@ def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, + hf_quant_cfg: dict[str, Any], + user_quant: str | None, + hf_config: Any = None, ) -> QuantizationMethods | None: """ Detects if this quantization method can support a given checkpoint format by overriding the user specified quantization method -- this method should only be overwritten by subclasses in exceptional - circumstances + circumstances. + + Args: + hf_quant_cfg: The checkpoint's quantization config dict. + user_quant: The user-specified quantization method string. + hf_config: The HuggingFace model config object (e.g. for + model_type checks). May be None if not available. 
""" return None diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py index 8ec569042d72..aea1067ff262 100644 --- a/vllm/model_executor/layers/quantization/cpu_wna16.py +++ b/vllm/model_executor/layers/quantization/cpu_wna16.py @@ -104,7 +104,7 @@ def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> "QuantizationMethods | None": quant_method = hf_quant_cfg.get("quant_method", "").lower() if current_platform.is_cpu() and (quant_method == "awq"): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 2a72da26cc62..61eb6c912a11 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -84,7 +84,7 @@ def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg: dict[str, Any], user_quant: str | None + cls, hf_quant_cfg: dict[str, Any], user_quant: str | None, hf_config=None ) -> "QuantizationMethods | None": # When user explicitly specifies --quantization gguf, override # whatever quantization method is in the HF model config (e.g. fp8). diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index ce0dc0f4e052..1ca551d6351b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -214,7 +214,7 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index 93be5b76130c..29d73928478d 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -453,7 +453,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> "QuantizationMethods | None": """Override the `auto-round` method to `inc`.""" is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round" diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 20965c77f0a6..0b8ad0cbc1ed 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -406,7 +406,7 @@ def get_min_capability(cls) -> int: @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and algo == "FP8": @@ -1028,7 +1028,7 @@ def get_min_capability(cls) -> int: @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and ("NVFP4" in algo or "FP4" in algo): @@ -1525,7 +1525,7 @@ def get_min_capability(cls) -> int: 
@classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and "MXFP8" in algo: @@ -2052,7 +2052,7 @@ def get_min_capability(cls) -> int: @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and algo == "MIXED_PRECISION": diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index a327ac17bbc9..e5ef3f4c3168 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -130,7 +130,7 @@ def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant + cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) if can_convert and user_quant == "moe_wna16": diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index adb191b0a0fe..019bb45d65dc 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -19,11 +19,11 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( TRITON_BACKENDS, Mxfp4MoeBackend, - convert_to_mxfp4_moe_kernel_format, + convert_gpt_oss_weight_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, - select_mxfp4_moe_backend, + select_gpt_oss_mxfp4_moe_backend, ) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods @@ -38,6 +38,12 @@ class Mxfp4Config(QuantizationConfig): + """Canonical base config for MXFP4 quantization. + + Subclasses override get_name() and override_quantization_method() to + register themselves as the handler for a specific checkpoint format. + """ + def __init__(self, ignored_layers: list[str] | None = None): super().__init__() self.ignored_layers = ignored_layers @@ -62,6 +68,8 @@ def get_supported_act_dtypes(cls) -> list[torch.dtype]: def get_config_filenames(cls) -> list[str]: return [] + # TODO (zyongye) This is only a temporary fallback. + # We should have `Mxfp4MoEMethod` here once this migration is complete. def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": @@ -79,7 +87,7 @@ def get_quant_method( ) return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return Mxfp4MoEMethod(layer.moe_config) + return GptOssMxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): logger.debug_once( "MXFP4 attention layer is not implemented. " @@ -93,13 +101,46 @@ def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: return True -class Mxfp4MoEMethod(FusedMoEMethodBase): +class GptOssMxfp4Config(Mxfp4Config): + """MXFP4 config for GPT-OSS checkpoints. + + Checkpoints carry ``"quant_method": "mxfp4"`` in their JSON config. + override_quantization_method() maps that to the canonical internal name + so that the rest of the loading path uses "gpt_oss_mxfp4" consistently.
+ """ + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "gpt_oss_mxfp4" + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + # Match both "mxfp4" (original checkpoint value) and "gpt_oss_mxfp4" + # (already normalized by verify_and_update_model_config) so that + # explicit --quantization mxfp4 from the user doesn't cause a mismatch. + if not ( + isinstance(hf_quant_cfg, dict) + and hf_quant_cfg.get("quant_method") in ("mxfp4", "gpt_oss_mxfp4") + ): + return None + # Require explicit confirmation that this is a GPT-OSS model. + # Do NOT fall back to returning the override when hf_config is None, + # as that would silently claim all mxfp4 checkpoints. + model_type = getattr(hf_config, "model_type", None) + if model_type != "gpt_oss": + return None + return "gpt_oss_mxfp4" + + +class GptOssMxfp4MoEMethod(FusedMoEMethodBase): """MXFP4 MoE quantization method.""" def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.weight_dtype = "mxfp4" - self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) + self.weight_dtype = "gpt_oss_mxfp4" + self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size @@ -281,7 +322,7 @@ def _setup_kernel( # Convert weights to kernel format w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( - convert_to_mxfp4_moe_kernel_format( + convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( mxfp4_backend=self.mxfp4_backend, layer=layer, w13_weight=w13, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 58ed8940b385..d4db929eaeb6 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -30,11 +30,11 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( TRITON_BACKENDS, Mxfp4MoeBackend, - convert_to_mxfp4_moe_kernel_format, + convert_gpt_oss_weight_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, - select_mxfp4_moe_backend, + select_gpt_oss_mxfp4_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, @@ -995,7 +995,7 @@ def __init__( self.w2_precision_config = None if self.ocp_mx_scheme == "w_mxfp4": - self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) + self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe) elif self.ocp_mx_scheme.startswith("w_mxfp4"): # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. 
self.mxfp4_backend = Mxfp4MoeBackend.NONE @@ -1300,7 +1300,7 @@ def _setup_kernel_via_oracle(self, layer: FusedMoE): # Convert weights to kernel format w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( - convert_to_mxfp4_moe_kernel_format( + convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( mxfp4_backend=self.mxfp4_backend, layer=layer, w13_weight=w13, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6e511f4f0a36..daa693472b72 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -108,6 +108,23 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class GptOssForCausalLMConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + quant_config = getattr(model_config.hf_config, "quantization_config", None) + if quant_config is not None and quant_config.get("quant_method") == "mxfp4": + model_config.hf_config.quantization_config["quant_method"] = "gpt_oss_mxfp4" + + hf_text_quant_config = getattr( + model_config.hf_text_config, "quantization_config", None + ) + if ( + hf_text_quant_config is not None + and hf_text_quant_config.get("quant_method") == "mxfp4" + ): + model_config.hf_text_config.quantization_config["quant_method"] = ( + "gpt_oss_mxfp4" + ) + @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: structured_outputs_config = vllm_config.structured_outputs_config diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index a9ec82974227..4e4eb581842d 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -560,6 +560,14 @@ def _load_weights_quark( pcp_rank=get_pcp_group().rank_in_group, ) + def _is_mxfp4(weight_dtype: str | None) -> bool: + """Return True for any MXFP4 weight-dtype variant. + + Covers "gpt_oss_mxfp4" (GptOssMxfp4MoEMethod), "mxfp4" + (QuarkMoEMethod with fp4 weights), and any future variants. + """ + return weight_dtype is not None and "mxfp4" in weight_dtype + def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: """Helper function to get MoE quantization weight dtype. @@ -578,7 +586,7 @@ def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: moe_weight_dtype = _get_moe_weight_dtype(layer_id=0) - if moe_weight_dtype == "mxfp4": + if _is_mxfp4(moe_weight_dtype): # MXFP4 requires OCP_MX_BLOCK_SIZE alignment intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) @@ -682,7 +690,7 @@ def kv_cache_scale_loader( continue # Unified handler for mxfp4 weights and scales - elif moe_quant_method == "mxfp4" and any( + elif _is_mxfp4(moe_quant_method) and any( name.endswith(suffix) for suffix in [ ".w13_weight_scale", @@ -1116,8 +1124,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if hasattr(self.config, "quantization_config") else None ) - + # Normalize the checkpoint's quant_method to the internal name. + # Note: there are three places where "mxfp4" -> "gpt_oss_mxfp4" + # normalization occurs, each serving a different data path: + # 1. GptOssMxfp4Config.override_quantization_method() — sets + # ModelConfig.quantization (used to select the QuantizationConfig + # class at model init time), reading from model_arch_config, which + # is a snapshot taken before verify_and_update_model_config runs. + # 2.
GptOssForCausalLMConfig.verify_and_update_model_config() — + # patches hf_config.quantization_config in-place (a separate copy + # of the dict from model_arch_config) for later hf_config lookups. + # 3. Here — reads directly from self.config (the raw HF config), which + # may still carry the original "mxfp4" string from the checkpoint. if quant_method == "mxfp4": + quant_method = "gpt_oss_mxfp4" + + if quant_method == "gpt_oss_mxfp4": return self._load_weights_mxfp4( ep_rank_end, ep_rank_start, diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py index 8b3ef56c80a9..b7967985f222 100644 --- a/vllm/model_executor/models/transformers/base.py +++ b/vllm/model_executor/models/transformers/base.py @@ -150,7 +150,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): if self.quant_config: quant_method_name = self.quant_config.get_name() # Check for unsupported quantization methods. - if quant_method_name == "mxfp4": + if quant_method_name in ("mxfp4", "gpt_oss_mxfp4"): raise NotImplementedError( "Transformers modeling backend does " "not support MXFP4 quantization yet."
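For readers tracing the new dispatch, here is a minimal, self-contained sketch of the override-resolution flow this patch wires up. The names `HfConfig`, `GptOssMxfp4ConfigSketch`, and `resolve_quantization` are illustrative stand-ins, not vLLM APIs; only the gating logic mirrors `GptOssMxfp4Config.override_quantization_method()` and the probe loop in `ModelConfig._verify_quantization` shown in the diff above.

```python
# Illustrative sketch only -- these names are NOT vLLM APIs. It mirrors the
# probe loop in ModelConfig._verify_quantization and the model_type gate in
# GptOssMxfp4Config.override_quantization_method from the patch above.
from dataclasses import dataclass
from typing import Any


@dataclass
class HfConfig:
    """Simplified stand-in for the HuggingFace model config."""

    model_type: str
    quantization_config: dict[str, Any]


class GptOssMxfp4ConfigSketch:
    @classmethod
    def override_quantization_method(
        cls,
        hf_quant_cfg: dict[str, Any],
        user_quant: str | None,
        hf_config: Any = None,
    ) -> str | None:
        # Accept both the raw checkpoint value ("mxfp4") and the value
        # already normalized by verify_and_update_model_config.
        if not (
            isinstance(hf_quant_cfg, dict)
            and hf_quant_cfg.get("quant_method") in ("mxfp4", "gpt_oss_mxfp4")
        ):
            return None
        # Positive model_type confirmation: a missing hf_config must not
        # silently claim every mxfp4 checkpoint.
        if getattr(hf_config, "model_type", None) != "gpt_oss":
            return None
        return "gpt_oss_mxfp4"


def resolve_quantization(hf_config: HfConfig, user_quant: str | None) -> str | None:
    """Simplified analogue of the override-detection loop."""
    quant_cfg = hf_config.quantization_config
    for method in (GptOssMxfp4ConfigSketch,):  # heavy backends are probed last
        override = method.override_quantization_method(
            quant_cfg, user_quant, hf_config=hf_config
        )
        if override is not None:
            return override
    # No override claimed the checkpoint; fall back to the user's choice
    # or the checkpoint's own quant_method.
    return user_quant or quant_cfg.get("quant_method")


if __name__ == "__main__":
    gpt_oss = HfConfig("gpt_oss", {"quant_method": "mxfp4"})
    llama = HfConfig("llama", {"quant_method": "mxfp4"})
    assert resolve_quantization(gpt_oss, None) == "gpt_oss_mxfp4"
    assert resolve_quantization(llama, None) == "mxfp4"  # left untouched
```

The design point the sketch highlights is the strict `model_type` gate: when `hf_config` is `None`, the method declines rather than claiming the checkpoint, so non-GPT-OSS mxfp4 models keep whatever handler would otherwise apply.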