Merged
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
@@ -59,7 +59,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`GptOssMxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.GptOssMxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]

## Fused Experts Kernels
3 changes: 2 additions & 1 deletion vllm/config/model.py
@@ -951,6 +951,7 @@ def _verify_quantization(self) -> None:
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"gpt_oss_mxfp4",
"cpu_awq",
"gguf",
]
@@ -966,7 +967,7 @@ def _verify_quantization(self) -> None:
for name in quantization_methods:
method = me_quant.get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization
quant_cfg, self.quantization, hf_config=self.hf_config
)
if quantization_override is not None:
# Raise error if the override is not custom (custom would
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -1063,7 +1063,7 @@ def weight_loader(
expert_id: int,
return_success: bool = False,
) -> bool | None:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -194,7 +194,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
return None


def select_mxfp4_moe_backend(
def select_gpt_oss_mxfp4_moe_backend(
Contributor:

@zyongye IMO these are not gpt-oss specific. quark_moe.py is being refactored to adopt the mxfp4 oracle, which supports any model with mxfp4 MoE. (You might see duplicate comments from me; I sent one in the wrong place earlier.)

Member Author:

We are temporarily routing all mxfp4 to gpt_oss_mxfp4 for compatibility reasons. We will later create another mxfp4 MoE method that is decoupled from this one. Once that method is added, we can route the AMD-related changes to the new MoE class.

config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
@@ -400,7 +400,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
return hidden_size, intermediate_size


def convert_to_mxfp4_moe_kernel_format(
def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module,
w13_weight: torch.Tensor,
@@ -426,7 +426,10 @@ def convert_to_mxfp4_moe_kernel_format(

sf_block_size = 32 # mxfp4 block size

if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
if mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin,
)
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
@@ -30,6 +30,7 @@
"torchao",
"inc",
"mxfp4",
"gpt_oss_mxfp4",
"mxfp8",
"cpu_awq",
"online",
@@ -133,7 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
ModelOptNvFp4Config,
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp4 import GptOssMxfp4Config, Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .torchao import TorchAOConfig
@@ -160,6 +161,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": INCConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"mxfp8": Mxfp8Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -232,7 +232,7 @@ def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
# Skip override to marlin kernels, as they are not
# batch invariant
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -110,13 +110,22 @@ def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls,
hf_quant_cfg: dict[str, Any],
user_quant: str | None,
hf_config: Any = None,
) -> QuantizationMethods | None:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
circumstances.

Args:
hf_quant_cfg: The checkpoint's quantization config dict.
user_quant: The user-specified quantization method string.
hf_config: The HuggingFace model config object (e.g. for
model_type checks). May be None if not available.
"""
return None

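For context on the new `hf_config` argument documented above, here is a minimal, hypothetical sketch of how a subclass might use it. The `MyFmtConfig` class, the `"my_fmt"` method name, and the `"my_model"` model type are invented for illustration and are not part of this PR; the real return annotation is `QuantizationMethods | None`.

```python
# Hypothetical sketch only: "my_fmt" and "my_model" are made-up names used to
# illustrate the new hf_config parameter; they do not exist in vLLM.
from typing import Any

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MyFmtConfig(QuantizationConfig):
    @classmethod
    def override_quantization_method(
        cls,
        hf_quant_cfg: dict[str, Any],
        user_quant: str | None,
        hf_config: Any = None,
    ) -> str | None:
        # Claim the checkpoint only when its quantization_config matches...
        if not isinstance(hf_quant_cfg, dict):
            return None
        if hf_quant_cfg.get("quant_method") != "my_fmt":
            return None
        # ...and the HF model config confirms the expected architecture.
        # hf_config may be None; declining in that case avoids claiming
        # checkpoints the method cannot verify.
        if getattr(hf_config, "model_type", None) != "my_model":
            return None
        return "my_fmt"
```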
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/cpu_wna16.py
@@ -104,7 +104,7 @@ def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "awq"):
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gguf.py
@@ -84,7 +84,7 @@ def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None, hf_config=None
) -> "QuantizationMethods | None":
# When user explicitly specifies --quantization gguf, override
# whatever quantization method is in the HF model config (e.g. fp8).
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -214,7 +214,7 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/inc.py
@@ -453,7 +453,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str):

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
"""Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -406,7 +406,7 @@ def get_min_capability(cls) -> int:

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "FP8":
@@ -1028,7 +1028,7 @@ def get_min_capability(cls) -> int:

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
@@ -1525,7 +1525,7 @@ def get_min_capability(cls) -> int:

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and "MXFP8" in algo:
@@ -2052,7 +2052,7 @@ def get_min_capability(cls) -> int:

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION":
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/moe_wna16.py
@@ -130,7 +130,7 @@ def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
55 changes: 48 additions & 7 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -19,11 +19,11 @@
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
select_gpt_oss_mxfp4_moe_backend,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -38,6 +38,12 @@


class Mxfp4Config(QuantizationConfig):
"""Canonical base config for MXFP4 quantization.

Subclasses override get_name() and override_quantization_method() to
register themselves as the handler for a specific checkpoint format.
"""

def __init__(self, ignored_layers: list[str] | None = None):
super().__init__()
self.ignored_layers = ignored_layers
@@ -62,6 +68,8 @@ def get_supported_act_dtypes(cls) -> list[torch.dtype]:
def get_config_filenames(cls) -> list[str]:
return []

# TODO (zyongye) This is only a temporary fallback.
# We should have `Mxfp4MoEMethod` after this migration is complete.
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
@@ -79,7 +87,7 @@ def get_quant_method(
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
return GptOssMxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
logger.debug_once(
"MXFP4 attention layer is not implemented. "
@@ -93,13 +101,46 @@ def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
return True


class Mxfp4MoEMethod(FusedMoEMethodBase):
class GptOssMxfp4Config(Mxfp4Config):
"""MXFP4 config for GPT-OSS checkpoints.

Checkpoints carry ``"quant_method": "mxfp4"`` in their JSON config.
override_quantization_method() maps that to the canonical internal name
so that the rest of the loading path uses "gpt_oss_mxfp4" consistently.
"""

@classmethod
def get_name(cls) -> QuantizationMethods:
return "gpt_oss_mxfp4"

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
# Match both "mxfp4" (original checkpoint value) and "gpt_oss_mxfp4"
# (already normalized by verify_and_update_model_config) so that
# explicit --quantization mxfp4 from the user doesn't cause a mismatch.
if not (
isinstance(hf_quant_cfg, dict)
and hf_quant_cfg.get("quant_method") in ("mxfp4", "gpt_oss_mxfp4")
):
return None
# Require explicit confirmation that this is a GPT-OSS model.
# Do NOT fall back to returning the override when hf_config is None,
# as that would silently claim all mxfp4 checkpoints.
model_type = getattr(hf_config, "model_type", None)
if model_type != "gpt_oss":
return None
return "gpt_oss_mxfp4"
Comment on lines +117 to +134
Contributor:

high

The override_quantization_method logic in GptOssMxfp4Config has two critical issues:

  1. It returns "gpt_oss_mxfp4" if hf_config is None (since model_type will be None), which is too aggressive for a model-specific config. It should only return the override if it can confirm the model is "gpt_oss".
  2. It only checks for "quant_method": "mxfp4". However, GptOssForCausalLMConfig.verify_and_update_model_config (in models/config.py) normalizes this to "gpt_oss_mxfp4" before this check runs. If a user explicitly passes --quantization mxfp4, this method will return None, leading to a mismatch error in vllm/config/model.py because self.quantization ("mxfp4") won't match the normalized quant_method ("gpt_oss_mxfp4").
Suggested change
def override_quantization_method(
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
if not (
isinstance(hf_quant_cfg, dict)
and hf_quant_cfg.get("quant_method") == "mxfp4"
):
return None
model_type = getattr(hf_config, "model_type", None)
if model_type is not None and model_type != "gpt_oss":
return None
return "gpt_oss_mxfp4"
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
if not isinstance(hf_quant_cfg, dict):
return None
quant_method = hf_quant_cfg.get("quant_method")
if quant_method not in ("mxfp4", "gpt_oss_mxfp4"):
return None
if getattr(hf_config, "model_type", None) != "gpt_oss":
return None
return "gpt_oss_mxfp4"



class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method."""

def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.weight_dtype = "gpt_oss_mxfp4"
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)

self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
@@ -281,7 +322,7 @@ def _setup_kernel(

# Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format(
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -30,11 +30,11 @@
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
select_gpt_oss_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
@@ -995,7 +995,7 @@ def __init__(
self.w2_precision_config = None

if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE
@@ -1300,7 +1300,7 @@ def _setup_kernel_via_oracle(self, layer: FusedMoE):

# Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format(
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
17 changes: 17 additions & 0 deletions vllm/model_executor/models/config.py
@@ -108,6 +108,23 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:


class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
quant_config = getattr(model_config.hf_config, "quantization_config", None)
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
model_config.hf_config.quantization_config["quant_method"] = "gpt_oss_mxfp4"

hf_text_quant_config = getattr(
model_config.hf_text_config, "quantization_config", None
)
if (
hf_text_quant_config is not None
and hf_text_quant_config.get("quant_method") == "mxfp4"
):
model_config.hf_text_config.quantization_config["quant_method"] = (
"gpt_oss_mxfp4"
)

@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
structured_outputs_config = vllm_config.structured_outputs_config
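A minimal sketch (not part of this change) of the intended effect of the normalization hook above, with `SimpleNamespace` objects standing in for the real HuggingFace and vLLM config objects.

```python
# Illustration only: the "mxfp4" quant_method on a GPT-OSS checkpoint is
# rewritten to the internal "gpt_oss_mxfp4" name before quantization probing.
from types import SimpleNamespace

from vllm.model_executor.models.config import GptOssForCausalLMConfig

hf_config = SimpleNamespace(quantization_config={"quant_method": "mxfp4"})
model_config = SimpleNamespace(hf_config=hf_config, hf_text_config=hf_config)

GptOssForCausalLMConfig.verify_and_update_model_config(model_config)

# GptOssMxfp4Config is then the config class that claims this checkpoint.
assert hf_config.quantization_config["quant_method"] == "gpt_oss_mxfp4"
```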