diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f56edce2239d..50f8442ea2de 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1083,7 +1083,7 @@ def weight_loader( expert_id: int, return_success: bool = False, ) -> bool | None: - if self.quant_config and self.quant_config.get_name() == "mxfp4": + if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: dim1 = loaded_weight.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/gpt_oss_mxfp4.py similarity index 84% rename from vllm/model_executor/layers/fused_moe/oracle/mxfp4.py rename to vllm/model_executor/layers/fused_moe/oracle/gpt_oss_mxfp4.py index 9008bdeeca7e..7a4a69685947 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/gpt_oss_mxfp4.py @@ -43,7 +43,7 @@ ) -class Mxfp4MoeBackend(Enum): +class GptOssMxfp4MoeBackend(Enum): NONE = "None" # FlashInfer TRTLLM backends FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8" @@ -65,22 +65,22 @@ class Mxfp4MoeBackend(Enum): # Backends that share the same TRTLLM weight format TRTLLM_BACKENDS = ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, ) TRITON_BACKENDS = ( - Mxfp4MoeBackend.TRITON, - Mxfp4MoeBackend.TRITON_UNFUSED, + GptOssMxfp4MoeBackend.TRITON, + GptOssMxfp4MoeBackend.TRITON_UNFUSED, ) def backend_to_kernel_cls( - backend: Mxfp4MoeBackend, + backend: GptOssMxfp4MoeBackend, ) -> list[type[mk.FusedMoEExperts]]: if backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, ): from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import ( TrtLlmMxfp4ExpertsModular, @@ -91,8 +91,8 @@ def backend_to_kernel_cls( return [TrtLlmMxfp4ExpertsMonolithic, TrtLlmMxfp4ExpertsModular] elif backend in ( - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, @@ -100,7 +100,7 @@ def backend_to_kernel_cls( return [FlashInferExperts] - elif backend == Mxfp4MoeBackend.TRITON: + elif backend == GptOssMxfp4MoeBackend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, OAITritonMxfp4ExpertsMonolithic, @@ -109,35 +109,35 @@ def backend_to_kernel_cls( # NOTE: prefer Monolithic > Modular, so return Monolithic first. 
return [OAITritonMxfp4ExpertsMonolithic, OAITritonExperts] - elif backend == Mxfp4MoeBackend.TRITON_UNFUSED: + elif backend == GptOssMxfp4MoeBackend.TRITON_UNFUSED: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( UnfusedOAITritonExperts, ) return [UnfusedOAITritonExperts] - elif backend == Mxfp4MoeBackend.MARLIN: + elif backend == GptOssMxfp4MoeBackend.MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( MarlinExperts, ) return [MarlinExperts] - elif backend == Mxfp4MoeBackend.BATCHED_MARLIN: + elif backend == GptOssMxfp4MoeBackend.BATCHED_MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( BatchedMarlinExperts, ) return [BatchedMarlinExperts] - elif backend == Mxfp4MoeBackend.CK: + elif backend == GptOssMxfp4MoeBackend.CK: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( AiterExperts, ) return [AiterExperts] - elif backend == Mxfp4MoeBackend.XPU: + elif backend == GptOssMxfp4MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4 return [XPUExpertsMXFp4] @@ -146,17 +146,17 @@ def backend_to_kernel_cls( raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}") -def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: - """Map user's moe_backend string to Mxfp4MoeBackend.""" - mapping: dict[str, Mxfp4MoeBackend] = { - "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, - "triton": Mxfp4MoeBackend.TRITON, - "marlin": Mxfp4MoeBackend.MARLIN, - "ck": Mxfp4MoeBackend.CK, - "xpu": Mxfp4MoeBackend.XPU, +def map_mxfp4_backend(runner_backend: str) -> GptOssMxfp4MoeBackend: + """Map user's moe_backend string to GptOssMxfp4MoeBackend.""" + mapping: dict[str, GptOssMxfp4MoeBackend] = { + "flashinfer_trtllm": GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + "flashinfer_trtllm_afp8": GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + "flashinfer_cutlass": GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + "flashinfer_cutlass_afp8": GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + "triton": GptOssMxfp4MoeBackend.TRITON, + "marlin": GptOssMxfp4MoeBackend.MARLIN, + "ck": GptOssMxfp4MoeBackend.CK, + "xpu": GptOssMxfp4MoeBackend.XPU, } if backend := mapping.get(runner_backend): return backend @@ -166,29 +166,29 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: ) -def _get_priority_backends() -> list[Mxfp4MoeBackend]: +def _get_priority_backends() -> list[GptOssMxfp4MoeBackend]: """ Get available backends in priority order based on platform and config. Only includes BF16 backends. MXFP8 backends are selected via env vars. 
""" _AVAILABLE_BACKENDS = [ - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.CK, - Mxfp4MoeBackend.TRITON, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.TRITON_UNFUSED, - Mxfp4MoeBackend.MARLIN, - Mxfp4MoeBackend.BATCHED_MARLIN, - Mxfp4MoeBackend.XPU, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + GptOssMxfp4MoeBackend.CK, + GptOssMxfp4MoeBackend.TRITON, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.TRITON_UNFUSED, + GptOssMxfp4MoeBackend.MARLIN, + GptOssMxfp4MoeBackend.BATCHED_MARLIN, + GptOssMxfp4MoeBackend.XPU, ] return _AVAILABLE_BACKENDS -def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: +def _backend_activation_key(backend: GptOssMxfp4MoeBackend) -> QuantKey | None: """Map backend to its activation key (MXFP8 or None for BF16).""" if backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): return kMxfp8Dynamic return None @@ -196,7 +196,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: def select_mxfp4_moe_backend( config: FusedMoEConfig, -) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: +) -> tuple[GptOssMxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: """ Select the primary MXFP4 MoE backend. Note: Shape-specific fallbacks may still occur at runtime. @@ -216,11 +216,12 @@ def select_mxfp4_moe_backend( raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.") if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: logger.info_once("Using Triton backend for mxfp4 lora") - return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( - Mxfp4MoeBackend.TRITON_UNFUSED + return GptOssMxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( + GptOssMxfp4MoeBackend.TRITON_UNFUSED )[0] logger.info_once("Using Marlin backend for mxfp4 lora") - return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0] + marlin_cls = backend_to_kernel_cls(GptOssMxfp4MoeBackend.MARLIN) + return GptOssMxfp4MoeBackend.MARLIN, marlin_cls[0] activation_format = ( mk.FusedMoEActivationFormat.BatchedExperts @@ -228,10 +229,12 @@ def select_mxfp4_moe_backend( else mk.FusedMoEActivationFormat.Standard ) - def _make_log_backend(backend: Mxfp4MoeBackend): + def _make_log_backend(backend: GptOssMxfp4MoeBackend): return f"Using '{backend.value}' Mxfp4 MoE backend." 
- def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str: + def _make_log_unsupported( + backend: GptOssMxfp4MoeBackend, reason: str | None + ) -> str: if reason: return ( f"Mxfp4 MoE backend '{backend.value}' does not support the " @@ -243,12 +246,12 @@ def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str: ) def _return_or_raise( - backend: Mxfp4MoeBackend, + backend: GptOssMxfp4MoeBackend, config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]: + ) -> tuple[GptOssMxfp4MoeBackend, type[mk.FusedMoEExperts]]: reason: str | None = None for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( @@ -264,9 +267,9 @@ def _return_or_raise( requested_backend = map_mxfp4_backend(runner_backend) if ( activation_format == mk.FusedMoEActivationFormat.BatchedExperts - and requested_backend == Mxfp4MoeBackend.MARLIN + and requested_backend == GptOssMxfp4MoeBackend.MARLIN ): - requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN + requested_backend = GptOssMxfp4MoeBackend.BATCHED_MARLIN return _return_or_raise( requested_backend, config, @@ -281,12 +284,16 @@ def _return_or_raise( # Handle explicit FlashInfer MXFP4 BF16 configuration. if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: - AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16) - AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16) + AVAILABLE_BACKENDS.remove( + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16 + ) + AVAILABLE_BACKENDS.remove( + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16 + ) else: if current_platform.is_device_capability(90): return _return_or_raise( - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, config, kMxfp4Static, None, @@ -294,7 +301,7 @@ def _return_or_raise( ) if current_platform.is_device_capability_family(100): return _return_or_raise( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, config, kMxfp4Static, None, @@ -312,7 +319,7 @@ def _return_or_raise( and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): return _return_or_raise( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, config, kMxfp4Static, kMxfp8Dynamic, @@ -325,7 +332,7 @@ def _return_or_raise( and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): return _return_or_raise( - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, config, kMxfp4Static, kMxfp8Dynamic, @@ -335,7 +342,7 @@ def _return_or_raise( # Handle explicit Marlin MXFP4 configuration. if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN: return _return_or_raise( - Mxfp4MoeBackend.MARLIN, + GptOssMxfp4MoeBackend.MARLIN, config, kMxfp4Static, None, @@ -355,10 +362,10 @@ def _return_or_raise( logger.debug_once(_make_log_unsupported(backend, reason), scope="local") if current_platform.is_xpu(): - backend = Mxfp4MoeBackend.XPU + backend = GptOssMxfp4MoeBackend.XPU logger.info_once(_make_log_backend(backend)) return _return_or_raise( - Mxfp4MoeBackend.XPU, + GptOssMxfp4MoeBackend.XPU, config, kMxfp4Static, None, @@ -370,14 +377,14 @@ def _return_or_raise( "No MXFP4 MoE backend supports the deployment configuration." 
) - return Mxfp4MoeBackend.NONE, None + return GptOssMxfp4MoeBackend.NONE, None def mxfp4_round_up_hidden_size_and_intermediate_size( - backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int + backend: GptOssMxfp4MoeBackend, hidden_size: int, intermediate_size: int ) -> tuple[int, int]: """Round up hidden_size and intermediate_size based on backend requirements.""" - if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + if backend in (GptOssMxfp4MoeBackend.MARLIN, GptOssMxfp4MoeBackend.BATCHED_MARLIN): intermediate_size = round_up(intermediate_size, 128) if current_platform.is_xpu(): hidden_size = round_up(hidden_size, 128) @@ -387,8 +394,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size( intermediate_size = round_up(intermediate_size, 256) hidden_size = round_up(hidden_size, 256) elif backend in ( - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): intermediate_size = round_up(intermediate_size, 128) hidden_size = round_up(hidden_size, 128) @@ -401,7 +408,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size( def convert_to_mxfp4_moe_kernel_format( - mxfp4_backend: Mxfp4MoeBackend, + mxfp4_backend: GptOssMxfp4MoeBackend, layer: torch.nn.Module, w13_weight: torch.Tensor, w2_weight: torch.Tensor, @@ -426,7 +433,10 @@ def convert_to_mxfp4_moe_kernel_format( sf_block_size = 32 # mxfp4 block size - if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + if mxfp4_backend in ( + GptOssMxfp4MoeBackend.MARLIN, + GptOssMxfp4MoeBackend.BATCHED_MARLIN, + ): from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_mxfp4_layer_for_marlin, ) @@ -583,8 +593,8 @@ def swap_every_two_rows(x, axis=-1): ) elif mxfp4_backend in ( - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): # De-interleave and swap for w13 weight, bias, and scales w13_w = w13_weight.data @@ -606,7 +616,7 @@ def swap_every_two_rows(x, axis=-1): s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) w13_scale_swapped = torch.cat([s3, s1], dim=1) - if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: + if mxfp4_backend == GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: from flashinfer import block_scale_interleave orig_shape = w13_scale_swapped.shape @@ -630,7 +640,7 @@ def swap_every_two_rows(x, axis=-1): ) else: - assert mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16 + assert mxfp4_backend == GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16 def _interleave_mxfp4_cutlass_sm90(w): w_shape = w.shape @@ -656,7 +666,7 @@ def _interleave_mxfp4_cutlass_sm90(w): w2_bias, ) - elif mxfp4_backend == Mxfp4MoeBackend.CK: + elif mxfp4_backend == GptOssMxfp4MoeBackend.CK: from vllm._aiter_ops import rocm_aiter_ops if w13_bias is not None: @@ -752,7 +762,7 @@ def _interleave_mxfp4_cutlass_sm90(w): w13_bias, w2_bias, ) - elif mxfp4_backend == Mxfp4MoeBackend.XPU: + elif mxfp4_backend == GptOssMxfp4MoeBackend.XPU: # No additional transformation needed for XPU backend return ( w13_weight, @@ -765,12 +775,12 @@ def _interleave_mxfp4_cutlass_sm90(w): else: raise ValueError( f"Unsupported mxfp4_backend: {mxfp4_backend}: " - f"should be one of: {list(Mxfp4MoeBackend)}." 
+ f"should be one of: {list(GptOssMxfp4MoeBackend)}." ) def make_mxfp4_moe_quant_config( - mxfp4_backend: Mxfp4MoeBackend, + mxfp4_backend: GptOssMxfp4MoeBackend, w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], w1_bias: torch.Tensor | None = None, @@ -778,8 +788,8 @@ def make_mxfp4_moe_quant_config( ) -> FusedMoEQuantConfig | None: """Create a FusedMoEQuantConfig for the given MXFP4 backend.""" if mxfp4_backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): return mxfp4_mxfp8_moe_quant_config( w1_bias=w1_bias, @@ -788,13 +798,13 @@ def make_mxfp4_moe_quant_config( w2_scale=w2_scale, ) elif mxfp4_backend in ( - Mxfp4MoeBackend.MARLIN, - Mxfp4MoeBackend.BATCHED_MARLIN, - Mxfp4MoeBackend.TRITON, - Mxfp4MoeBackend.TRITON_UNFUSED, - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.CK, + GptOssMxfp4MoeBackend.MARLIN, + GptOssMxfp4MoeBackend.BATCHED_MARLIN, + GptOssMxfp4MoeBackend.TRITON, + GptOssMxfp4MoeBackend.TRITON_UNFUSED, + GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + GptOssMxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + GptOssMxfp4MoeBackend.CK, ): return mxfp4_w4a16_moe_quant_config( w1_bias=w1_bias, @@ -816,7 +826,7 @@ def make_mxfp4_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, experts_cls: type[mk.FusedMoEExperts], - mxfp4_backend: Mxfp4MoeBackend, + mxfp4_backend: GptOssMxfp4MoeBackend, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEKernel: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9aceb3be054d..bbf29060a6a3 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -30,6 +30,7 @@ "torchao", "inc", "mxfp4", + "gpt_oss_mxfp4", "mxfp8", "petit_nvfp4", "cpu_awq", @@ -117,6 +118,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .fp8 import Fp8Config from .fp_quant import FPQuantConfig from .gguf import GGUFConfig + from .gpt_oss_mxfp4 import GptOssMxfp4Config from .gptq import GPTQConfig from .gptq_marlin import GPTQMarlinConfig from .inc import INCConfig @@ -127,7 +129,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ModelOptNvFp4Config, ) from .moe_wna16 import MoeWNA16Config - from .mxfp4 import Mxfp4Config from .mxfp8 import Mxfp8Config from .petit import PetitNvFp4Config from .torchao import TorchAOConfig @@ -153,7 +154,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "torchao": TorchAOConfig, "auto-round": INCConfig, "inc": INCConfig, - "mxfp4": Mxfp4Config, + "mxfp4": GptOssMxfp4Config, + "gpt_oss_mxfp4": GptOssMxfp4Config, "mxfp8": Mxfp8Config, "petit_nvfp4": PetitNvFp4Config, "cpu_awq": CPUAWQConfig, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1b8b726d9714..7aa7270afe08 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -45,8 +45,8 @@ 
make_fp8_moe_quant_config, select_fp8_moe_backend, ) -from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( - Mxfp4MoeBackend, +from vllm.model_executor.layers.fused_moe.oracle.gpt_oss_mxfp4 import ( + GptOssMxfp4MoeBackend, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, ) @@ -238,7 +238,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): def __init__(self, moe): super().__init__(moe) self.group_size = 32 - self.mxfp4_backend = Mxfp4MoeBackend.MARLIN + self.mxfp4_backend = GptOssMxfp4MoeBackend.MARLIN self.experts_cls = MarlinExperts def create_weights( diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/gpt_oss_mxfp4.py similarity index 96% rename from vllm/model_executor/layers/quantization/mxfp4.py rename to vllm/model_executor/layers/quantization/gpt_oss_mxfp4.py index c69e99a68126..ec43989bd9f0 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/gpt_oss_mxfp4.py @@ -16,9 +16,9 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( +from vllm.model_executor.layers.fused_moe.oracle.gpt_oss_mxfp4 import ( TRITON_BACKENDS, - Mxfp4MoeBackend, + GptOssMxfp4MoeBackend, convert_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, @@ -37,7 +37,7 @@ logger = init_logger(__name__) -class Mxfp4Config(QuantizationConfig): +class GptOssMxfp4Config(QuantizationConfig): def __init__(self, ignored_layers: list[str] | None = None): super().__init__() self.ignored_layers = ignored_layers @@ -52,7 +52,7 @@ def get_min_capability(cls) -> int: @classmethod def get_name(cls) -> QuantizationMethods: - return "mxfp4" + return "gpt_oss_mxfp4" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: @@ -79,7 +79,7 @@ def get_quant_method( ) return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return Mxfp4MoEMethod(layer.moe_config) + return GptOssMxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): logger.debug_once( "MXFP4 attention layer is not implemented. 
" @@ -93,12 +93,12 @@ def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: return True -class Mxfp4MoEMethod(FusedMoEMethodBase): +class GptOssMxfp4MoEMethod(FusedMoEMethodBase): """MXFP4 MoE quantization method.""" def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.weight_dtype = "mxfp4" + self.weight_dtype = "gpt_oss_mxfp4" self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.max_capture_size = ( @@ -116,7 +116,7 @@ def __init__(self, moe: FusedMoEConfig): def skip_forward_padding(self) -> bool: # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant # so can skip the padding in the forward before applying the moe method - return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 + return self.mxfp4_backend == GptOssMxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 def maybe_roundup_sizes( self, @@ -333,7 +333,7 @@ def process_weights_after_loading(self, layer): w13_bias = getattr(layer, "w13_bias", None) w2_bias = getattr(layer, "w2_bias", None) - if self.mxfp4_backend == Mxfp4MoeBackend.NONE: + if self.mxfp4_backend == GptOssMxfp4MoeBackend.NONE: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index a58ee5c44e00..926b07b95777 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -26,8 +26,8 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( - Mxfp4MoeBackend, +from vllm.model_executor.layers.fused_moe.oracle.gpt_oss_mxfp4 import ( + GptOssMxfp4MoeBackend, mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) @@ -699,12 +699,12 @@ def __init__( f"Please check that the combination is supported in OCP_MX_Scheme." ) - self.mxfp4_backend: Mxfp4MoeBackend | None = None + self.mxfp4_backend: GptOssMxfp4MoeBackend | None = None if self.ocp_mx_scheme == "w_mxfp4": self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe) elif self.ocp_mx_scheme.startswith("w_mxfp4"): # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. 
- self.mxfp4_backend = Mxfp4MoeBackend.NONE + self.mxfp4_backend = GptOssMxfp4MoeBackend.NONE if self.input_quant is not None: self.static_input_scales = not self.input_quant.get("is_dynamic") @@ -739,7 +739,7 @@ def __init__( or not self.ocp_mx_scheme.startswith("w_mxfp4") ) and ( self.mxfp4_backend is None - or self.mxfp4_backend is Mxfp4MoeBackend.NONE + or self.mxfp4_backend is GptOssMxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 482056250a1e..49d46692ee03 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -575,7 +575,7 @@ def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: moe_weight_dtype = _get_moe_weight_dtype(layer_id=0) - if moe_weight_dtype == "mxfp4": + if moe_weight_dtype == "gpt_oss_mxfp4": # MXFP4 requires OCP_MX_BLOCK_SIZE alignment intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) @@ -679,7 +679,7 @@ def kv_cache_scale_loader( continue # Unified handler for mxfp4 weights and scales - elif moe_quant_method == "mxfp4" and any( + elif moe_quant_method == "gpt_oss_mxfp4" and any( name.endswith(suffix) for suffix in [ ".w13_weight_scale", @@ -1114,7 +1114,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: else None ) - if quant_method == "mxfp4": + if quant_method in ("mxfp4", "gpt_oss_mxfp4"): return self._load_weights_mxfp4( ep_rank_end, ep_rank_start,
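
For reference, a minimal, self-contained sketch of the registry aliasing this patch introduces in `quantization/__init__.py`. The key names and `GptOssMxfp4Config` are taken from the hunks above, but the module scaffolding here is illustrative, not vLLM source: the legacy `"mxfp4"` key and the new `"gpt_oss_mxfp4"` key both resolve to the same config class, while `get_name()` reports only the new canonical name, which is why `weight_loader` in `layer.py` now compares against `"gpt_oss_mxfp4"`.

```python
# Illustrative sketch (not vLLM source) of the quant-method aliasing above.
class GptOssMxfp4Config:
    @classmethod
    def get_name(cls) -> str:
        # Canonical name after the rename; "mxfp4" remains a loader alias.
        return "gpt_oss_mxfp4"


_METHOD_TO_CONFIG: dict[str, type] = {
    "mxfp4": GptOssMxfp4Config,          # legacy alias, kept for compatibility
    "gpt_oss_mxfp4": GptOssMxfp4Config,  # new canonical key
}


def get_quantization_config(quantization: str) -> type:
    if quantization not in _METHOD_TO_CONFIG:
        raise ValueError(f"Unknown quantization method: {quantization}")
    return _METHOD_TO_CONFIG[quantization]


if __name__ == "__main__":
    # Both keys resolve to the same class...
    assert get_quantization_config("mxfp4") is get_quantization_config("gpt_oss_mxfp4")
    # ...but runtime checks compare against the canonical name.
    assert get_quantization_config("mxfp4").get_name() == "gpt_oss_mxfp4"
```

Keeping the old key mapped presumably preserves compatibility with checkpoints whose config still declares `quant_method: mxfp4`, matching the `quant_method in ("mxfp4", "gpt_oss_mxfp4")` check retained in `gpt_oss.py`.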