diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 384f43db479b..4900949ad780 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -22,7 +22,7 @@ from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( triton_kernel_moe_forward, ) @@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): pc2, ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) - quant_config = FusedMoEQuantConfig.make( - w1_bias=w1_bias_tri, - w2_bias=w2_bias_tri, - w1_scale=pc1, - w2_scale=pc2, - ) + if a_dtype == "bf16" and w_dtype == "mx4": + quant_config = mxfp4_w4a16_moe_quant_config( + w1_scale=pc1, + w2_scale=pc2, + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + ) + else: + raise NotImplementedError( + f"Quantization configuration for activation={a_dtype} and weight={w_dtype} " + f"has not been implemented." + ) out_triton_monolithic = triton_kernel_moe_forward( hidden_states=x_tri, diff --git a/tests/models/quantization/test_gpt_oss.py b/tests/models/quantization/test_gpt_oss.py new file mode 100644 index 000000000000..e70ccaf88b0e --- /dev/null +++ b/tests/models/quantization/test_gpt_oss.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +End-to-end accuracy test for GPT-OSS model quantization. + +Config: + Task: gsm8k_platinum + Filter: flexible-extract + n-shot: 5 + Metric: exact_match + +Run: pytest tests/models/quantization/test_gpt_oss.py +""" + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import huggingface_hub +import lm_eval +import pytest +from packaging import version + +MODEL_ACCURACIES = { + # Full quantization: attention linears and MoE linears + "amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89, + # MoE linears only quantization + "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89, + # MoE linears only quantization + # "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90, +} + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.9.0") + + +def has_huggingface_access(repo): + try: + huggingface_hub.list_repo_refs(repo) + return True + except huggingface_hub.errors.RepositoryNotFoundError: + return False + + +HF_HUB_AMD_ORG_ACCESS = all( + [has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES] +) + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class EvaluationConfig: + model_name: str + + def get_model_args(self, tp_size: int): + return { + "pretrained": self.model_name, + "chat_template_args": {"reasoning_effort": "low"}, + "enable_thinking": True, + "think_end_token": "200008", + "tensor_parallel_size": tp_size, + "dtype": "auto", + "gpu_memory_utilization": 0.95, + "trust_remote_code": False, + "enable_prefix_caching": False, + "enforce_eager": False, + } + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not HF_HUB_AMD_ORG_ACCESS, + reason="Read access to huggingface.co/amd is required for this test.", +) +@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items()) +def test_gpt_oss_attention_quantization( + model_name: str, tp_size: int, expected_accuracy: float +): + model_args = EvaluationConfig(model_name).get_model_args(tp_size) + + extra_run_kwargs = { + "gen_kwargs": {"max_gen_toks": 8000}, + "apply_chat_template": True, + "fewshot_as_multiturn": True, + "num_fewshot": 5, + } + + lm_eval_out = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks="gsm8k_platinum", + batch_size="auto", + **extra_run_kwargs, + ) + measured_accuracy = float( + lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"] + ) + + rtol = 0.02 + assert ( + measured_accuracy - rtol < expected_accuracy + and measured_accuracy + rtol > expected_accuracy + ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}" diff --git a/tests/models/quantization/test_gpt_oss_attn_quantization.py b/tests/models/quantization/test_gpt_oss_attn_quantization.py deleted file mode 100644 index 780165ea2ba7..000000000000 --- a/tests/models/quantization/test_gpt_oss_attn_quantization.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test attention quantization of gpt-oss model. -The qkv_proj and o_proj in self_attention can be either quantized or excluded. - -Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`. - -""" - -import importlib -import importlib.metadata -from dataclasses import dataclass - -import huggingface_hub -import lm_eval -import pytest -from packaging import version - -MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"] - -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( - importlib.metadata.version("amd-quark") -) >= version.parse("0.8.99") - - -def has_huggingface_access(repo): - try: - huggingface_hub.list_repo_refs(repo) - return True - except huggingface_hub.errors.RepositoryNotFoundError: - return False - - -HF_HUB_AMD_ORG_ACCESS = all( - [has_huggingface_access(model_name) for model_name in MODEL_NAMES] -) - - -@dataclass -class ModelCase: - model_id: str - tp: int - - -@dataclass -class EvaluationConfig: - model_name: str - - def get_model_args(self) -> str: - return ( - f"pretrained={self.model_name}," - "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False" - ) - - -EXPECTED_ACCURACIES = {"arc_challenge": 0.20} - - -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") -@pytest.mark.skipif( - not HF_HUB_AMD_ORG_ACCESS, - reason="Read access to huggingface.co/amd is required for this test.", -) -@pytest.mark.parametrize("model_name", MODEL_NAMES) -@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items()) -def test_gpt_oss_attention_quantization( - model_name: str, task_name: str, expected_accuracy: float -): - measured_accuracy = lm_eval.simple_evaluate( - model="vllm", - model_args=EvaluationConfig(model_name).get_model_args(), - tasks=task_name, - batch_size="auto", - )["results"][task_name]["acc,none"] - - rtol = 0.05 - assert ( - measured_accuracy - rtol < expected_accuracy - and measured_accuracy + rtol > expected_accuracy - ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}" diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 828e9d0f3640..b9fee1dd4a45 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -386,6 +386,10 @@ def use_mxfp4_w4a4(self) -> bool: def use_nvfp4_w4a4(self) -> bool: return self.quant_dtype == "nvfp4" + @property + def use_mxfp4_w4a8(self) -> bool: + return self._a1.dtype == "fp8" and self._w1.dtype == "mxfp4" + def config_name(self, dtype: torch.dtype) -> str | None: """ Return a string used to construct the filename that contains the @@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config( w2_scale: torch.Tensor, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: list[int] | None = None, @@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config( g1_alphas=g1_alphas, w2_scale=w2_scale, g2_alphas=g2_alphas, + w1_bias=w1_bias, + w2_bias=w2_bias, a1_scale=a1_scale, a1_gscale=a1_gscale, a2_scale=a2_scale, @@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config( w2_scale: torch.Tensor, a1_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, per_act_token_quant: bool = False, ) -> FusedMoEQuantConfig: """ @@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, per_act_token_quant=per_act_token_quant, per_out_ch_quant=False, block_shape=None, @@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config( ) +def mxfp4_w4a8_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and mxfp4 weights. + """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None), + _a2=FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + def ocp_mx_moe_quant_config( quant_dtype: str, w1_scale: Union[torch.Tensor, "PrecisionConfig"], @@ -691,6 +723,8 @@ def nvfp4_moe_quant_config( a2_gscale: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and nvp4 weights. @@ -699,6 +733,8 @@ def nvfp4_moe_quant_config( "nvfp4", w1_scale=w1_scale, w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, a1_gscale=a1_gscale, a2_gscale=a2_gscale, g1_alphas=g1_alphas, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e0907368bbc7..63aae43c3ddf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -38,7 +38,6 @@ ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 -from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8Dynamic128Sym, @@ -1583,6 +1582,11 @@ def _get_config_quant_dtype( return "mxfp6_e3m2" elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: return "mxfp6_e2m3" + elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: + return torch.bfloat16 + elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: + return torch.float8_e4m3fn + return None @@ -1617,17 +1621,10 @@ def fused_experts_impl( if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" elif ocp_mx_scheme is not None: - if ocp_mx_scheme in { - "w_mxfp4_a_mxfp4", - "w_mxfp4_a_mxfp6_e3m2", - "w_mxfp4_a_mxfp6_e2m3", - }: + if ocp_mx_scheme.startswith("w_mxfp4"): # 16bit activation and fp4x2 packed weight assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" - elif ocp_mx_scheme in { - "w_mxfp6_e3m2_a_mxfp6_e3m2", - "w_mxfp6_e2m3_a_mxfp6_e2m3", - }: + elif ocp_mx_scheme.startswith("w_mxfp6"): assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( "hidden size mismatch" ) @@ -1717,17 +1714,13 @@ def fused_experts_impl( # TODO: On platforms for which `current_platform.supports_mx()` is True # and for which we have a native OCP mx fused MOE kernel, # this dequantization step should not be done. - if ocp_mx_scheme in { - OCP_MX_Scheme.w_mxfp4_a_mxfp4, - OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, - OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, - }: + if ocp_mx_scheme.startswith("w_mxfp4"): # Weight has to be dequantized for mxfp4 emulation. w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) w1_scale = None w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) w2_scale = None - elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: + elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"): w1 = dequant_mxfp6( w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype ) @@ -1736,7 +1729,7 @@ def fused_experts_impl( w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype ) w2_scale = None - elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: + elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"): w1 = dequant_mxfp6( w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype ) @@ -1779,6 +1772,7 @@ def fused_experts_impl( quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, + ocp_mx_scheme=ocp_mx_scheme, ) # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k @@ -1846,6 +1840,7 @@ def fused_experts_impl( quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, + ocp_mx_scheme=ocp_mx_scheme, ) if expert_map is not None: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c3be1be8537a..f35ec87aac42 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: ) +# TODO(rob): move this down to the kernel. def maybe_roundup_hidden_size( hidden_size: int, act_dtype: torch.dtype, - quant_config: QuantizationConfig | None, moe_parallel_config: FusedMoEParallelConfig, is_lora_enabled: bool, + model_type: str | None, + is_mxfp4_quant: bool, ) -> int: """ Given layer hidden size and MoE configurations, round up hidden_size @@ -235,11 +237,12 @@ def maybe_roundup_hidden_size( Args: hidden_size: Layer hidden-size act_dtype: Data type of the layer activations. - quant_config: Fused MoE quantization configuration. moe_parallel_config: Fused MoE parallelization strategy configuration. is_lora_enabled: True if the engine is enabled with LoRA. This is used in the case of mxfp4 quantization in selecting the MxFP4Backend. + model_type: for checking if gpt-oss + is_mxfp4_quant: whether the layer is quantized with mxfp4 Return: Rounded up hidden_size if rounding up is required based on the configs. @@ -254,7 +257,7 @@ def maybe_roundup_hidden_size( ) # we are padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": + if model_type == "gpt_oss" and is_mxfp4_quant: from vllm.model_executor.layers.quantization.mxfp4 import ( Mxfp4Backend, get_mxfp4_backend, @@ -398,15 +401,6 @@ def __init__( # Expert mapping used in self.load_weights self.expert_mapping = expert_mapping - # Round up hidden size if needed. - hidden_size = maybe_roundup_hidden_size( - hidden_size, - moe_in_dtype, - quant_config, - self.moe_parallel_config, - is_lora_enabled=self.vllm_config.lora_config is not None, - ) - # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: @@ -508,7 +502,6 @@ def __init__( ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." assert intermediate_size % self.tp_size == 0 - self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -548,6 +541,24 @@ def __init__( ) self.routing_method_type: RoutingMethodType = self.router.routing_method_type + # Round up hidden size before creating moe_config. + # This way moe_config is created with the correct hidden_size from the start. + hidden_size = maybe_roundup_hidden_size( + hidden_size=hidden_size, + act_dtype=moe_in_dtype, + moe_parallel_config=self.moe_parallel_config, + is_lora_enabled=vllm_config.lora_config is not None, + model_type=( + self.vllm_config.model_config.hf_config.model_type + if self.vllm_config.model_config is not None + else None + ), + is_mxfp4_quant=( + quant_config is not None and quant_config.is_mxfp4_quant(prefix, self) + ), + ) + self.hidden_size = hidden_size + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 75873a92abdb..7d5ca876bdcc 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_e4m3_quantize, ) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + per_tensor_dequantize, +) from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -241,7 +244,27 @@ def moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: list[int] | None = None, is_fp4_scale_swizzled: bool = True, + ocp_mx_scheme: str | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation + if ocp_mx_scheme is not None: + if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: + pass # No QDQ needed for these schemes + elif ocp_mx_scheme.endswith("a_fp8"): + # Perform QDQ (quantize and dequantize) on activation for emulation + # purpose, because there is no native kernel for weight in ocp_mx_scheme + # and activation in FP8. The implementation is based on existing + # non-emulation ops. + qA, qA_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=False + ) + A = per_tensor_dequantize(qA, qA_scale).to(A.dtype) + # After QDQ, we don't need further quantization + return A, None + # else: For other schemes (e.g., *_a_mxfp6_e3m2, *_a_mxfp6_e2m3), + # weights are already dequantized, and we proceed with normal + # activation quantization below. + if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index c8a8424eb5c8..a10264865073 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -168,3 +168,19 @@ def maybe_update_config(self, model_name: str): # noqa: B027 Interface to update values after config initialization. """ pass + + def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: + """ + Determine if mxfp4 quantization will be used for this config. + + This allows hidden_size rounding to happen before moe_config creation + without needing to instantiate quant_method first. + + Args: + prefix: The layer prefix/name in the model + layer: The layer module + + Returns: + True if this config uses MXFP4 quantization, False otherwise + """ + return False diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index b9dec4530565..d1c9cb6bb2cf 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -229,10 +229,15 @@ def get_quant_method( ) return None + def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: + """MXFP4 config always uses MXFP4 quantization.""" + return True + class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) + self.weight_dtype = "mxfp4" self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.marlin_input_dtype = None diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index dd6db7193235..2e75a3de564b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -320,38 +320,45 @@ def _is_static_tensor_w8a8( # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static - def _is_ocp_mx( - self, - weight_quant: dict[str, Any] | None, - input_quant: dict[str, Any] | None, + def _is_w_ocp_mx_a_x( + self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None ) -> bool: - # Confirm weights and input quantized. - if weight_quant is None or input_quant is None: + """ + This check returns True only if it is an OCP-MX weight quantization. + The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format). + The rationale for checking only the weight type is that + the model loading concept and process primarily concerns the weights themselves. + """ + # Confirm weights quantized. + if weight_quant is None: logger.debug( - "Quark model is not in OCP MX format: " - "weight_quant or input_quant not set" + "Quark model's weight quantization is incompatible with OCP_MX format: " + "weight_quant is not set." ) return False # Input and weight qscheme needs to be per group. - if ( - weight_quant.get("qscheme") != "per_group" - or input_quant.get("qscheme") != "per_group" - ): - logger.debug("Quark model is not in OCP MX format: not per_group") + if weight_quant.get("qscheme") != "per_group": + logger.debug( + "Quark model's weight quantization is incompatible with OCP MX format: " + "weight is not per_group." + ) return False # Input and weight group size needs to be 32. - if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: - logger.debug("Quark model is not in OCP MX format: not group_size=32") + if weight_quant.get("group_size") != 32: + logger.debug( + "Quark model's weight quantization is incompatible with OCP MX format: " + "group_size of weight is not 32." + ) return False # Activations and weight scales need to be in e8m0 format. - if ( - weight_quant.get("scale_format") != "e8m0" - or input_quant.get("scale_format") != "e8m0" - ): - logger.debug("Quark model is not in OCP MX format: not scale_format e8m0") + if weight_quant.get("scale_format") != "e8m0": + logger.debug( + "Quark model's weight quantization is incompatible with OCP MX format: " + "scale_format of weight is not e8m0." + ) return False # Input and weight dtypes need to be any of fp4, @@ -360,14 +367,31 @@ def _is_ocp_mx( "fp4", "fp6_e3m2", "fp6_e2m3", - } or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}: + }: logger.debug( - "Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3" + "Quark model's weight quantization is incompatible with OCP MX format: " + "dtype is not in {fp4, fp6_e3m2, fp6_e2m3}." ) return False return True + def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: + """ + For Quark, determine if it's OCP MXFP4 by checking config directly. + This allows hidden_size rounding to happen before moe_config creation. + """ + layer_quant_config = self._find_matched_config(prefix, layer) + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + return ( + self._is_w_ocp_mx_a_x(weight_config, input_config) + and weight_config is not None + and weight_config.get("dtype") == "fp4" + and getattr(torch, "float4_e2m1fn_x2", None) is not None + ) + def _find_matched_config( self, layer_name: str, module: torch.nn.Module ) -> dict[str, Any]: @@ -441,7 +465,7 @@ def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": is_static_input_scheme=True, input_symmetric=input_config.get("symmetric"), ) - elif self._is_ocp_mx(weight_config, input_config): + elif self._is_w_ocp_mx_a_x(weight_config, input_config): return QuarkOCP_MX(weight_config, input_config) raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index fc836c56be1d..190890130e33 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -18,9 +19,15 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + mxfp4_w4a8_moe_quant_config, + mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, + get_mxfp4_backend, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) @@ -37,6 +44,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.math_utils import round_up logger = init_logger(__name__) @@ -46,6 +54,7 @@ class QuarkMoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) + self.has_bias = self.moe.has_bias @staticmethod def get_moe_method( @@ -67,7 +76,7 @@ def get_moe_method( return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) - elif quant_config._is_ocp_mx(weight_config, input_config): + elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config): return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -86,6 +95,10 @@ def __init__( self.weight_qscheme = self.weight_quant.get("qscheme") self.input_qscheme = self.input_quant.get("qscheme") + self.weight_dtype = self.weight_quant.get("dtype", "").replace( + "fp8_e4m3", "fp8" + ) + self.input_dtype = self.input_quant.get("dtype", "").replace("fp8_e4m3", "fp8") per_tensor = ( self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor" ) @@ -121,6 +134,10 @@ def __init__( self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.model_type = getattr( + get_current_vllm_config().model_config.hf_config, "model_type", None + ) + def create_weights( self, layer: torch.nn.Module, @@ -166,9 +183,16 @@ def create_weights( if self.weight_qscheme == "per_tensor": # Allocate 2 scales for w1 and w3 respectively. # They are combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) + if self.model_type != "gpt_oss": + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + else: + # For gpt_oss, the w1(gate) & w3(up) are fused as one. + # Therefore, only one weight scale for each expert. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 1, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False @@ -220,6 +244,27 @@ def create_weights( layer.w13_input_scale = None layer.w2_input_scale = None + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + layer.w13_bias, layer.w2_bias = None, None + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. @@ -278,21 +323,40 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): + + # For gpt_oss, w1 and w3 are fused into a single combined + # gate_up_proj tensor with size 2*intermediate_size_per_partition + # and only one scale per expert. + # Process the entire weight tensor as one shard. + if self.model_type == "gpt_oss": + for expert_id in range(layer.local_num_experts): + # Process all 2*intermediate_size_per_partition rows at once dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], + layer.w13_weight[expert_id], + layer.w13_weight_scale[expert_id][0], ) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id] ) - start += shard_size + else: + # For non-gpt_oss, process w1 and w3 shards separately + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + start += shard_size layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False ) + # quark's scale is 1 dim. elif self.weight_qscheme == "per_channel": if self.act_quant_group_shape == GroupShape.PER_TOKEN: @@ -343,6 +407,8 @@ def get_fused_moe_quant_config( w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, per_act_token_quant=self.input_qscheme == "per_channel", per_out_ch_quant=self.weight_qscheme == "per_channel", ) @@ -563,7 +629,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): def __init__( self, weight_config: dict[str, Any], - input_config: dict[str, Any], + input_config: dict[str, Any] | None, moe: FusedMoEConfig, ): super().__init__(moe) @@ -571,35 +637,79 @@ def __init__( self.input_quant = input_config weight_qscheme = self.weight_quant.get("qscheme") - input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): + if not weight_qscheme == "per_group": raise ValueError( "For MX(FP4) Fused MoE layers, only per-group scales " - "for weights and activations are supported. Found " - f"{weight_qscheme}, {input_qscheme}" + f"for weights are supported. Found {weight_qscheme}." ) # noqa E501 - self.static_input_scales = not self.input_quant.get("is_dynamic") - self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp") - self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp") + if self.input_quant is not None: + input_quant = self.input_quant["dtype"] + if input_quant in ["fp4", "fp6_e3m2", "fp6_e2m3"]: + self.input_dtype = input_quant.replace("fp", "mxfp") + elif input_quant == "fp8_e4m3": + self.input_dtype = input_quant.replace("fp8_e4m3", "fp8") + else: + raise NotImplementedError( + f"Current input dtype {input_quant} is not compatible \ + with OCP MX (weight) MoE quantization. Please open an issue" + ) + else: + self.input_dtype = None + self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( self.input_dtype, self.weight_dtype ) - if self.static_input_scales: + if self.ocp_mx_scheme is None: + raise ValueError( + f"Unsupported OCP MX dtype combination for MoE: " + f"input_dtype={self.input_dtype}, weight_dtype={self.weight_dtype}. " + f"Please check that the combination is supported in OCP_MX_Scheme." + ) + + self.mxfp4_backend: Mxfp4Backend | None = None + if self.ocp_mx_scheme == "w_mxfp4": + self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + + if self.input_quant is not None: + self.static_input_scales = not self.input_quant.get("is_dynamic") + else: + self.static_input_scales = False + + if any( + self.ocp_mx_scheme.endswith(a_scheme) + for a_scheme in ["a_mxfp4", "a_mxfp6_e3m2", "a_mxfp6_e2m3"] + ): + if self.static_input_scales: + raise NotImplementedError( + "QuarkOCP_MX_MoEMethod with static input scales is currently " + f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. " + "Please open an issue." + ) + elif self.ocp_mx_scheme.endswith("a_fp8") and not self.static_input_scales: raise NotImplementedError( - "QuarkOCP_MX_MoEMethod with static input scales is currently " - "not implemented. Please open an issue." + "QuarkOCP_MX_MoEMethod with dynamic input scales is currently " + f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. " + "Please open an issue." ) self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() - self.emulate = not current_platform.supports_mx() or not ( - self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + self.model_type = getattr( + get_current_vllm_config().model_config.hf_config, "model_type", None ) + + self._emulate = ( + not current_platform.supports_mx() + or not self.ocp_mx_scheme.startswith("w_mxfp4") + ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe) + + self.emulate = True if self.model_type == "gpt_oss" else self._emulate + if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " @@ -640,12 +750,23 @@ def create_weights( ) params_dtype = torch.uint8 + if self.model_type == "gpt_oss": + if current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256 + ) + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64 + ) + else: + intermediate_size_per_partition_after_pad = intermediate_size_per_partition # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + 2 * intermediate_size_per_partition_after_pad, self.get_packed_dim(hidden_size, self.weight_dtype), dtype=params_dtype, ), @@ -659,7 +780,9 @@ def create_weights( torch.empty( num_experts, hidden_size, - self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype), + self.get_packed_dim( + intermediate_size_per_partition_after_pad, self.weight_dtype + ), dtype=params_dtype, ), requires_grad=False, @@ -672,7 +795,7 @@ def create_weights( w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * intermediate_size_per_partition, + 2 * intermediate_size_per_partition_after_pad, hidden_size // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), @@ -682,7 +805,7 @@ def create_weights( torch.ones( num_experts, hidden_size, - intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), requires_grad=False, @@ -693,8 +816,96 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + layer.w13_bias, layer.w2_bias = None, None + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + def process_weights_after_loading(self, layer): + if self.static_input_scales: + # firstly, process activations if fp8 static input + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + _, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fnuz), + torch.empty_like( + layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype + ), + layer.w13_input_scale, + ) + _, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fnuz), + torch.empty_like( + layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype + ), + layer.w2_input_scale, + ) + # Reset the parameter + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + + # secondly, process mxfp weights if self.emulate: + torch.cuda.empty_cache() return from aiter.utility.fp4_utils import e8m0_shuffle @@ -725,15 +936,40 @@ def process_weights_after_loading(self, layer): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return ocp_mx_moe_quant_config( - quant_dtype=self.input_dtype, - weight_dtype=self.weight_dtype, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=None, - a2_scale=None, - block_shape=None, - ) + if self.ocp_mx_scheme == "w_mxfp4": + return mxfp4_w4a16_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + ) + elif self.ocp_mx_scheme == "w_mxfp4_a_fp8": + return mxfp4_w4a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + block_shape=None, + ) + elif self.ocp_mx_scheme in ["w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"]: + raise NotImplementedError( + "Currently there is no corresponding fused moe quant config configured " + f"in vLLM for OCP MX scheme {self.ocp_mx_scheme}. Please open an issue." + ) + else: + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + a1_scale=None, + a2_scale=None, + block_shape=None, + ) def apply( self, @@ -743,24 +979,34 @@ def apply( topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if not self.emulate: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts, - ) + if ( + self.model_type == "gpt_oss" + and self.mxfp4_backend == Mxfp4Backend.TRITON + ): + raise NotImplementedError( + "Triton kernel implemented fused MoE for GPT_OSS model " + "in Quark(MoE) format is not integrated or provided yet." + ) - out = rocm_aiter_fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - quant_config=self.moe_quant_config, - expert_map=layer.expert_map, - ) + else: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + quant_config=self.moe_quant_config, + expert_map=layer.expert_map, + ) else: from vllm.model_executor.layers.fused_moe import fused_experts - out = fused_experts( + return fused_experts( x, layer.w13_weight, layer.w2_weight, @@ -773,5 +1019,3 @@ def apply( expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) - - return out diff --git a/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py index 7752324f41fe..a9157cbfb08b 100644 --- a/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py +++ b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py @@ -20,26 +20,44 @@ class OCP_MX_Scheme(str, Enum): + w_mxfp4 = "w_mxfp4" w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4" w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2" w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3" + w_mxfp4_a_fp8 = "w_mxfp4_a_fp8" + w_mxfp6_e3m2 = "w_mxfp6_e3m2" w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2" + w_mxfp6_e3m2_a_fp8 = "w_mxfp6_e3m2_a_fp8" + w_mxfp6_e2m3 = "w_mxfp6_e2m3" w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3" + w_mxfp6_e2m3_a_fp8 = "w_mxfp6_e2m3_a_fp8" @classmethod def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None): - if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES: + if input_dtype not in OCP_MX_DTYPES and weight_dtype not in OCP_MX_DTYPES: return None + elif input_dtype is None and weight_dtype == "mxfp4": + return cls.w_mxfp4 + elif input_dtype is None and weight_dtype == "mxfp6_e3m2": + return cls.w_mxfp6_e3m2 + elif input_dtype is None and weight_dtype == "mxfp6_e2m3": + return cls.w_mxfp6_e2m3 elif input_dtype == "mxfp4" and weight_dtype == "mxfp4": return cls.w_mxfp4_a_mxfp4 elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4": return cls.w_mxfp4_a_mxfp6_e3m2 elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4": return cls.w_mxfp4_a_mxfp6_e2m3 + elif input_dtype == "fp8" and weight_dtype == "mxfp4": + return cls.w_mxfp4_a_fp8 elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2": return cls.w_mxfp6_e3m2_a_mxfp6_e3m2 + elif input_dtype == "fp8" and weight_dtype == "mxfp6_e3m2": + return cls.w_mxfp6_e3m2_a_fp8 elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3": return cls.w_mxfp6_e2m3_a_mxfp6_e2m3 + elif input_dtype == "fp8" and weight_dtype == "mxfp6_e2m3": + return cls.w_mxfp6_e2m3_a_fp8 else: logger.warning( "input_dtype='%s' and" diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index f62771c3697a..28c37c64b8ab 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable import torch import torch.distributed as dist @@ -25,13 +26,17 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.utils import rocm_unquantized_gemm from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -98,6 +103,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.num_attention_heads, total_num_kv_heads=self.num_key_value_heads, + bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) @@ -105,6 +111,7 @@ def __init__( self.o_proj = RowParallelLinear( input_size=self.num_attention_heads * self.head_dim, output_size=self.hidden_size, + bias=True, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) @@ -306,6 +313,19 @@ def forward( return x, aux_hidden_states return x + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, weight scales, activation scales + # (param_name, weight_name, expert_id, shard_id) + # NOTE: this is only used for quark. + return FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + num_redundant_experts=0, + ) + def _load_weights_mxfp4( self, ep_rank_end: int, @@ -318,7 +338,6 @@ def _load_weights_mxfp4( params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - mxfp4_block = 32 use_ep = self.parallel_config.enable_expert_parallel num_experts = self.config.num_local_experts @@ -333,9 +352,11 @@ def _load_weights_mxfp4( ) intermediate_size = self.config.intermediate_size - intermediate_size_block = intermediate_size // mxfp4_block + intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) - per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block + per_rank_intermediate_size = ( + per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE + ) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size @@ -370,7 +391,9 @@ def _load_weights_mxfp4( narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[ - ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block + ..., + tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end + // OCP_MX_BLOCK_SIZE, ] param = params_dict[name] @@ -495,6 +518,449 @@ def _load_weights_mxfp4( loaded_params.add(name) return loaded_params + def _load_weights_quark( + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts + + if use_ep: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + else: + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( + tp_size=get_tensor_model_parallel_world_size(), + dp_size=get_dp_group().world_size, + dp_rank=get_dp_group().rank_in_group, + pcp_size=get_pcp_group().world_size, + pcp_rank=get_pcp_group().rank_in_group, + ) + + def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: + """Helper function to get MoE quantization weight dtype. + + Args: + layer_id: Layer index to check (default 0, as all layers should + have the same quantization method) + + Returns: + Weight dtype string (e.g., "mxfp4", "fp8") or None if not available + """ + if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"): + return self.layers[layer_id].mlp.experts.quant_method.weight_dtype + return None + + intermediate_size = self.config.intermediate_size + + moe_weight_dtype = _get_moe_weight_dtype(layer_id=0) + + if moe_weight_dtype == "mxfp4": + # MXFP4 requires OCP_MX_BLOCK_SIZE alignment + intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE + per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) + per_rank_intermediate_size = ( + per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE + ) + else: + # FP8 and other formats don't need alignment + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + + layer_id, expert_id, fused_name = None, None, None + moe_quant_method = None + if "experts" in name: + parts = name.split(".") + ids = [s for s in parts if s.isdigit()] + + # for amd-quark format that each expert is seperated + # need to extract the parameter name with experts fused. + # example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8 + if len(ids) == 2: + layer_id, expert_id = int(ids[0]), int(ids[-1]) + parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id))) + fused_name = ".".join(parts) + + # for openai mxfp4 format that all experts are combined + # no need to extract the parameter name with experts fused. + # models: openai/gpt-oss-20b, openai/gpt-oss-120b + elif len(ids) == 1: + layer_id, expert_id = int(ids[0]), None + fused_name = name + + else: + raise NameError( + f"Layer {name} contains more than 2 numeric indices. This is " + "an unexpected condition. Please open an issue if encountered." + ) + + moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id) + + def kv_cache_scale_loader( + quant_config: QuantizationConfig, + name: str, + params_dict: dict[str, typing.Any], + weight: torch.Tensor, + default_weight_loader: Callable[..., None], + loaded_params: set[str], + ) -> tuple[bool, set[str]]: + """ + Load KV cache output scales. + Returns: + Tuple of (bool, set): + - bool: True if KV-cache scale was loaded into loaded_params + - set: Updated set of loaded_params if True else the original set + """ + # load explicit cached KV output scale from quant_config + if quant_config is not None and ( + scale_name := quant_config.get_cache_scale(name) + ): + param = params_dict[scale_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + if weight.numel() != 1: + raise ValueError( + f"KV cache scale '{scale_name}' is expected to be a " + f"scalar, but got a tensor of shape {weight.shape}." + ) + # Ensure weight is a scalar before passing to loader. + weight_loader(param, weight.flatten()[0]) + loaded_params.add(scale_name) + return True, loaded_params + + return False, loaded_params + + load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader( + self.quant_config, + name, + params_dict, + loaded_weight, + default_weight_loader, + loaded_params, + ) + if load_kv_cache_scale_completed: + continue + + if ( + all(key in name for key in ["input_scale", "mlp.experts"]) + and expert_id is not None + ): + assert loaded_weight.numel() == 1 + expert_data = params_dict[fused_name].data[expert_id] + expert_data.copy_(loaded_weight) + loaded_params.add(fused_name) + continue + + # Unified handler for mxfp4 weights and scales + elif moe_quant_method == "mxfp4" and any( + name.endswith(suffix) + for suffix in [ + ".w13_weight_scale", + ".w2_weight_scale", + ".w13_weight", + ".w2_weight", + ] + ): + is_w13 = ".w13_" in name + is_scale = "_scale" in name + + # Reshape weight for mxfp4 if needed (not for scales) + if not is_scale and expert_id is None: + if is_w13: + if loaded_weight.dim() < 3: + raise ValueError( + f"Expected w13_weight to have at least 3 " + f"dimensions, got shape " + f"{loaded_weight.shape}" + ) + if loaded_weight.shape[0] != num_experts: + raise ValueError( + f"Expected w13_weight first dimension to be " + f"{num_experts}, got " + f"{loaded_weight.shape[0]}" + ) + loaded_weight = loaded_weight.view( + num_experts, 2 * intermediate_size, -1 + ).contiguous() + else: + if loaded_weight.dim() < 3: + raise ValueError( + f"Expected w2_weight to have at least 3 " + f"dimensions, got shape " + f"{loaded_weight.shape}" + ) + if loaded_weight.shape[0] != num_experts: + raise ValueError( + f"Expected w2_weight first dimension to be " + f"{num_experts}, got " + f"{loaded_weight.shape[0]}" + ) + loaded_weight = loaded_weight.view( + num_experts, -1, intermediate_size // 2 + ).contiguous() + + if use_ep: + sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + if is_w13: + if expert_id is None: + sliced_weight = loaded_weight[ + :, 2 * tp_rank_start : 2 * tp_rank_end, ... + ] + else: + sliced_weight = loaded_weight[ + 2 * tp_rank_start : 2 * tp_rank_end, ... + ] + else: + if is_scale: + sliced_weight = loaded_weight[ + ..., + tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end + // OCP_MX_BLOCK_SIZE, + ] + else: + sliced_weight = loaded_weight[ + ..., tp_rank_start // 2 : tp_rank_end // 2 + ] + + # NOTE(rob): because gpt-oss ckpt has "unique" structure with + # fused gate_up_proj fused on disk, we cannot use the existing + # weight loaders without added complexity, so just do the + # direct load here. + param = params_dict[fused_name] + expert_data = param.data[expert_id] + dim1 = sliced_weight.shape[0] + dim2 = sliced_weight.shape[1] + expert_data.data[:dim1, :dim2].copy_(sliced_weight) + loaded_params.add(fused_name) + continue + + elif name.endswith(".w13_weight") and moe_quant_method == "fp8": + if use_ep: + narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + if expert_id is None: + narrow_weight = loaded_weight[ + :, 2 * tp_rank_start : 2 * tp_rank_end, : + ] + else: + narrow_weight = loaded_weight[ + 2 * tp_rank_start : 2 * tp_rank_end, : + ] + + assert fused_name is not None + param = params_dict[fused_name] + + if expert_id is None: + param.data.copy_(narrow_weight) + else: + param.data[expert_id].copy_(narrow_weight) + + loaded_params.add(fused_name) + continue + + elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8": + assert fused_name is not None + param = params_dict[fused_name] + + # Check if this is per-channel or per-tensor scale + if loaded_weight.numel() > 1 and loaded_weight.dim() == 1: + if use_ep: + narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = loaded_weight[ + 2 * tp_rank_start : 2 * tp_rank_end + ] + else: + narrow_weight = loaded_weight + + if expert_id is None: + param.data.copy_(narrow_weight) + else: + param.data[expert_id].copy_(narrow_weight) + + loaded_params.add(fused_name) + continue + + elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8": + assert fused_name is not None + param = params_dict[fused_name] + + if expert_id is None: + param.data.copy_(loaded_weight) + else: + param.data[expert_id].copy_(loaded_weight) + + loaded_params.add(fused_name) + continue + + elif name.endswith(".w2_weight") and moe_quant_method == "fp8": + if use_ep: + narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + if expert_id is None: + narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end] + else: + narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end] + + assert fused_name is not None + param = params_dict[fused_name] + + if expert_id is None: + param.data.copy_(narrow_weight) + else: + param.data[expert_id].copy_(narrow_weight) + + loaded_params.add(fused_name) + continue + + elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8": + assert fused_name is not None + param = params_dict[fused_name] + + if use_ep: + narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = loaded_weight + + if expert_id is None: + param.data.copy_(narrow_weight) + else: + param.data[expert_id].copy_(narrow_weight) + + loaded_params.add(fused_name) + continue + + # Unified handler for bias loading (w13_bias and w2_bias) + elif name.endswith(".w13_bias") or name.endswith(".w2_bias"): + is_w13_bias = name.endswith(".w13_bias") + + if use_ep: + sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] + else: + if is_w13_bias: + if expert_id is None: + sliced_weight = loaded_weight[ + :, 2 * tp_rank_start : 2 * tp_rank_end + ] + else: + sliced_weight = loaded_weight[ + 2 * tp_rank_start : 2 * tp_rank_end + ] + else: + sliced_weight = loaded_weight + if tp_rank != 0: + sliced_weight = sliced_weight.zero_() + + # NOTE(rob): because gpt-oss ckpt has "unique" structure with + # fused gate_up_proj fused on disk, we cannot use the existing + # weight loaders without added complexity, so just do the + # direct load here. + assert fused_name is not None + param = params_dict[fused_name] + expert_data = param.data[expert_id] + dim1 = sliced_weight.shape[0] + expert_data.data[:dim1].copy_(sliced_weight) + loaded_params.add(fused_name) + continue + + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + param_name, weight_name, mapping_expert_id, shard_id = mapping + weight_name = ( + weight_name[:-1] if weight_name.endswith(".") else weight_name + ) + + if weight_name not in name: + continue + + param = params_dict[fused_name] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + # Use checkpoint's expert_id for quark format (when expert_id + # is extracted from weight name), otherwise use mapping's expert_id + actual_expert_id = ( + expert_id if expert_id is not None else mapping_expert_id + ) + success = weight_loader( + param, + loaded_weight, + fused_name, + shard_id=shard_id, + expert_id=actual_expert_id, + return_success=True, + ) + if success: + name = fused_name + loaded_params.add(name) + break + else: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + return loaded_params + def _load_weights_other( self, ep_rank_end: int, @@ -635,6 +1101,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if hasattr(self.config, "quantization_config") else None ) + if quant_method == "mxfp4": return self._load_weights_mxfp4( ep_rank_end, @@ -644,6 +1111,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights, stacked_params_mapping, ) + elif quant_method == "quark": + return self._load_weights_quark( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) else: return self._load_weights_other( ep_rank_end, @@ -676,6 +1152,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", + # For quark format + ".gate_up_proj.weight": ".w13_weight", + ".gate_up_proj.weight_scale": ".w13_weight_scale", + ".gate_up_proj.bias": ".w13_bias", + ".gate_up_proj.input_scale": ".w13_input_scale", + ".down_proj.weight": ".w2_weight", + ".down_proj.weight_scale": ".w2_weight_scale", + ".down_proj.bias": ".w2_bias", + ".down_proj.input_scale": ".w2_input_scale", }, ) @@ -725,18 +1210,6 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, weight scales, activation scales - # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( - self, - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_local_experts, - num_redundant_experts=0, - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self,