diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4a8f31255ac6..d86bb6c3c864 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -40,6 +40,7 @@ kFp8Static128BlockSym, kFp8StaticChannelSym, kFp8StaticTensorSym, + kMxfp4Static, kNvfp4Static, ) from vllm.platforms import current_platform @@ -581,6 +582,7 @@ def _supports_quant_scheme( kFp8StaticChannelSym, kFp8StaticTensorSym, kNvfp4Static, + kMxfp4Static, ] return weight_key in SUPPORTED_W diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 5617156bf2fc..3c3243d2251d 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kMxfp4Static, ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -386,41 +387,32 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_current_device() -> bool: - raise NotImplementedError( - "OAITritonExperts is not yet used by an Oracle. " - "This method should not be called." + p = current_platform + return p.is_cuda_alike() and ( + p.is_device_capability(90) or p.is_device_capability_family(100) ) @staticmethod def _supports_no_act_and_mul() -> bool: - raise NotImplementedError( - "OAITritonExperts is not yet used by an Oracle. " - "This method should not be called." - ) + return False @staticmethod def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - raise NotImplementedError( - "OAITritonExperts is not yet used by an Oracle. " - "This method should not be called." - ) + SUPPORTED_W_A = [ + (kMxfp4Static, None), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: MoEActivation) -> bool: - raise NotImplementedError( - "OAITritonExperts is not yet used by an Oracle. " - "This method should not be called." - ) + raise NotImplementedError @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - raise NotImplementedError( - "OAITritonExperts is not yet used by an Oracle. " - "This method should not be called." - ) + return True def supports_expert_map(self) -> bool: return True @@ -477,6 +469,10 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): """OAI Triton-based fused MoE expert implementation.""" + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation == MoEActivation.SWIGLUOAI + @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -561,6 +557,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. """ + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [ + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, + ] + @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 679b79ce971b..ddfecef3b09c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -52,7 +52,6 @@ QuantizationConfig, ) from vllm.platforms import current_platform -from vllm.utils.math_utils import round_up logger = init_logger(__name__) @@ -245,28 +244,6 @@ def maybe_roundup_hidden_size( hidden_size, act_dtype, moe_parallel_config ) - # we are padding globally so EP buffer allocation works - if model_type == "gpt_oss" and is_mxfp4_quant: - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Backend, - get_mxfp4_backend, - ) - - current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled) - - if ( - current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - ): - hidden_size = round_up(hidden_size, 128) - elif ( - current_platform.is_rocm() - or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 - or current_mxfp4_backend == Mxfp4Backend.MARLIN - ): - hidden_size = round_up(hidden_size, 256) - return hidden_size diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py new file mode 100644 index 000000000000..f4fd29528b7d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -0,0 +1,867 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum +from typing import Union + +import torch +from torch.nn.parameter import Parameter + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoEConfig, +) +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + mxfp4_mxfp8_moe_quant_config, + mxfp4_w4a16_moe_quant_config, + ocp_mx_moe_quant_config, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + _swizzle_mxfp4, + get_padding_alignment, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kMxfp4Static, + kMxfp8Dynamic, +) +from vllm.platforms import current_platform +from vllm.utils.import_utils import has_triton_kernels +from vllm.utils.math_utils import round_up + +logger = init_logger(__name__) + +if has_triton_kernels(): + try: + from triton_kernels.matmul_ogs import PrecisionConfig + except (ImportError, AttributeError) as e: + logger.error( + "Failed to import Triton kernels. Please make sure your triton " + "version is compatible. Error: %s", + e, + ) + + +class Mxfp4MoeBackend(Enum): + # FIXME(zyongye) we temporarily treat monolithic and modular into 2 backend + # pending unifying them after https://github.com/vllm-project/vllm/pull/32564 + NONE = "None" + FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8" + FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC = ( + "FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC" + ) + FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_CUTLASS_MXFP4_MXFP8" + FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16" + FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC = "FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC" + FLASHINFER_CUTLASS_MXFP4_BF16 = "FLASHINFER_CUTLASS_MXFP4_BF16" + BATCHED_MARLIN = "BATCHED_MARLIN" + MARLIN = "MARLIN" + TRITON = "TRITON" + TRITON_MONOLITHIC = "TRITON_MONOLITHIC" + TRITON_UNFUSED = "TRITON_UNFUSED" + XPU = "XPU" + + +def backend_to_kernel_cls( + backend: Mxfp4MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.XPU, + ): + raise NotImplementedError + elif backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + ): + from vllm.model_executor.layers.fused_moe.trtllm_moe import ( + TrtLlmGenExperts, + ) + + return TrtLlmGenExperts + elif backend in ( + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + ): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + elif backend == Mxfp4MoeBackend.TRITON: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( + OAITritonExperts, + ) + + return OAITritonExperts + elif backend == Mxfp4MoeBackend.TRITON_UNFUSED: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( + UnfusedOAITritonExperts, + ) + + return UnfusedOAITritonExperts + elif backend == Mxfp4MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + elif backend == Mxfp4MoeBackend.BATCHED_MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, + ) + + return BatchedMarlinExperts + + else: + raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}") + + +def select_mxfp4_moe_backend( + config: FusedMoEConfig, +) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: + """ + Select the primary MXFP4 MoE backend. + Note: Shape-specific fallbacks may still occur at runtime. + """ + + # If FlashInfer is not available, try either Marlin or Triton + triton_kernels_supported = ( + has_triton_kernels() + # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 + # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 + # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 + and (9, 0) <= current_platform.get_device_capability() < (11, 0) + ) + + if config.is_lora_enabled: + if not current_platform.is_cuda(): + raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.") + + if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: + logger.info_once("Using Triton backend for mxfp4 lora") + return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( + Mxfp4MoeBackend.TRITON_UNFUSED + ) + + logger.info_once("Using Marlin backend for mxfp4 lora") + return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN) + + # FIXME(zyongye): we still need to fix kernel section + # after monolithic kernel refactor PR is merged + AVAILABLE_BACKENDS = [ + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + Mxfp4MoeBackend.MARLIN, + Mxfp4MoeBackend.BATCHED_MARLIN, + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.XPU, + ] + + # NOTE(zyongye): See similar comments in fp8.py + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) + + def _make_log_backend(backend: Mxfp4MoeBackend): + available_backend_strs = [b.value for b in AVAILABLE_BACKENDS] + return ( + f"Using {backend.value} Mxfp4 MoE backend out " + f"of potential backends: {available_backend_strs}." + ) + + def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"Mxfp4 MoE backend {backend.value} does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"Mxfp4 MoE backend '{backend.value}' does not support the " + "deployment configuration." + ) + + def _return_or_raise( + backend: Mxfp4MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + # Handle explicit FlashInfer MXFP4 BF16 configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): + if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16) + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16) + AVAILABLE_BACKENDS.remove( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC + ) + else: + if current_platform.is_device_capability(90): + backend = Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16 + return _return_or_raise( + backend, + config, + kMxfp4Static, + None, + activation_format, + ) + if current_platform.is_device_capability_family(100): + # Using modular interface + # unifying them after #32564 is merged + if config.dp_size > 1 and config.use_ep: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16 + return _return_or_raise( + backend, + config, + kMxfp4Static, + None, + activation_format, + ) + else: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC + return backend, None + + # Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"): + # same as BF16 case + if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8) + AVAILABLE_BACKENDS.remove( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC + ) + if config.dp_size > 1 and config.use_ep: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 + return _return_or_raise( + backend, + config, + kMxfp4Static, + kMxfp8Dynamic, + activation_format, + ) + else: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC + return backend, None + + # Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS"): + if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8) + else: + backend = Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8 + return _return_or_raise( + backend, + config, + kMxfp4Static, + kMxfp8Dynamic, + activation_format, + ) + + # Handle explicit Marlin MXFP4 configuration. + if envs.is_set("VLLM_MXFP4_USE_MARLIN"): + if not envs.VLLM_MXFP4_USE_MARLIN: + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.MARLIN) + AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.BATCHED_MARLIN) + else: + backend = Mxfp4MoeBackend.MARLIN + return _return_or_raise( + backend, + config, + kMxfp4Static, + None, + activation_format, + ) + + # FIXME(zyongye): manually select default kernels + # change to automatic after monolithic kernel PR is merged + if ( + current_platform.is_device_capability_family(100) + and Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16 in AVAILABLE_BACKENDS + ): + if config.dp_size > 1 and config.use_ep: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16 + return _return_or_raise( + backend, config, kMxfp4Static, None, activation_format + ) + else: + backend = Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC + logger.info_once(_make_log_backend(backend)) + return backend, None + elif current_platform.has_device_capability(90): + if config.dp_size > 1 and config.use_ep: + backend = Mxfp4MoeBackend.TRITON + return _return_or_raise( + backend, + config, + kMxfp4Static, + None, + activation_format, + ) + else: + backend = Mxfp4MoeBackend.TRITON_MONOLITHIC + logger.info_once(_make_log_backend(backend)) + return backend, None + elif current_platform.has_device_capability(70): + backend = ( + Mxfp4MoeBackend.MARLIN + if activation_format == mk.FusedMoEActivationFormat.Standard + else Mxfp4MoeBackend.BATCHED_MARLIN + ) + return _return_or_raise( + backend, + config, + kMxfp4Static, + None, + activation_format, + ) + elif current_platform.is_xpu(): + backend = Mxfp4MoeBackend.XPU + logger.info_once(_make_log_backend(backend)) + return backend, None + + if current_platform.is_cuda() or current_platform.is_rocm(): + raise NotImplementedError( + "No MXFP4 MoE backend supports the deployment configuration." + ) + + return Mxfp4MoeBackend.NONE, None + + +def convert_to_mxfp4_moe_kernel_format( + mxfp4_backend: Mxfp4MoeBackend, + layer: torch.nn.Module, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + w13_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w13_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Union[type[torch.Tensor], "PrecisionConfig"], + Union[type[torch.Tensor], "PrecisionConfig"], + type[torch.Tensor] | None, + type[torch.Tensor] | None, +]: + assert _cache_permute_indices is not None + + num_experts = w13_weight.shape[0] + intermediate_size = w13_weight.shape[1] // 2 + hidden_size = w13_weight.shape[2] * 2 + + sf_block_size = 32 # mxfp4 block size + assert w13_bias is not None and w2_bias is not None + + if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) = prepare_moe_mxfp4_layer_for_marlin( + layer, + w13=w13_weight, + w13_scale=w13_weight_scale, + w13_bias=w13_bias, + w2=w2_weight, + w2_scale=w2_weight_scale, + w2_bias=w2_bias, + ) + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + elif mxfp4_backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + ): + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache + + layer.gemm1_alpha = Parameter( + torch.tensor([1.702] * num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_beta = Parameter( + torch.tensor([1.0] * num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_clamp_limit = Parameter( + torch.tensor([7.0] * num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + + w13_weight = w13_weight.data + w2_weight = w2_weight.data + w13_weight_scale = w13_weight_scale.data + w2_weight_scale = w2_weight_scale.data + w13_bias = w13_bias.data.to(torch.float32) + w2_bias = w2_bias.data.to(torch.float32) + + # Swap w1 and w3 as the definition of + # swiglu is different in the trtllm-gen + def swap_every_two_rows(x, axis=-1): + shape = x.shape + if axis < 0: + axis = len(shape) + axis + + # Create a new shape with pairs swapped along specified axis + new_shape = list(shape) + new_shape[axis] = shape[axis] // 2 + new_shape.insert(axis + 1, 2) + + # Reshape to expose pairs, swap them, and reshape back + x = x.reshape(*new_shape) + x = x.flip(axis + 1) + new_shape = list(shape) + return x.reshape(*new_shape) + + w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) + w13_weight = swap_every_two_rows(w13_weight, -2) + w13_bias = swap_every_two_rows(w13_bias, -1) + + # Do not interleave as the checkpoint is already interleaved + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxfp4_shuffled = [] + gemm1_scales_mxfp4_shuffled = [] + gemm2_weights_mxfp4_shuffled = [] + gemm2_scales_mxfp4_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(num_experts): + # w13 weight shuffling + permute_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_mxfp4_shuffled.append( + w13_weight[i] + .view(torch.uint8)[permute_indices.to(w13_weight.device)] + .contiguous() + ) + # w13 scale shuffling + permute_sf_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm1_scales_mxfp4_shuffled.append( + nvfp4_block_scale_interleave( + w13_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)] + .contiguous() + ) + ) + # w13 bias shuffling + permute_bias_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append( + w13_bias[i] + .clone() + .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] + .contiguous() + ) + # w2 weight shuffling + permute_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_mxfp4_shuffled.append( + w2_weight[i] + .view(torch.uint8)[permute_indices.to(w2_weight.device)] + .contiguous() + ) + # w2 scale shuffling + permute_sf_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_mxfp4_shuffled.append( + nvfp4_block_scale_interleave( + w2_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)] + .contiguous() + ) + ) + # w2 bias shuffling + permute_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append( + w2_bias[i] + .clone() + .reshape(-1, 1)[permute_indices.to(w2_bias.device)] + .contiguous() + ) + + w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) + w13_weight_scale = ( + torch.stack(gemm1_scales_mxfp4_shuffled) + .reshape( + num_experts, + 2 * intermediate_size, + hidden_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) + + w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) + w2_weight_scale = ( + torch.stack(gemm2_scales_mxfp4_shuffled) + .reshape( + num_experts, + hidden_size, + intermediate_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) + + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + elif mxfp4_backend in ( + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + ): + # De-interleave and swap for w13 weight, bias, and scales + w13_w = w13_weight.data + gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] + deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) + w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + + w13_b = w13_bias.data.to(torch.float32) + gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] + deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) + b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) + w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w13_s = w13_weight_scale.data + gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] + deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1) + s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: + from flashinfer import block_scale_interleave + + orig_shape = w13_scale_swapped.shape + w13_scale_interleaved = block_scale_interleave( + w13_scale_swapped.view(torch.uint8) + ).reshape(orig_shape) + + w2_s = w2_weight_scale.data + orig_shape = w2_s.shape + w2_scale_interleaved = block_scale_interleave( + w2_s.view(torch.uint8) + ).reshape(orig_shape) + + w13_weight = w13_weight_swapped + w13_weight_scale = w13_scale_interleaved + w13_bias = w13_bias_swapped + w2_weight_scale = w2_scale_interleaved + + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16: + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape(w_shape[0], w_shape[1], (w_shape[2] // 4), 4) + w_interleaved = w_interleaved.permute(0, 2, 1, 3) + w_interleaved = w_interleaved.reshape( + w_shape[0], w_shape[2] // 4, w_shape[1] * 4 + ) + return w_interleaved + + w31_scales = w13_scale_swapped.to(torch.uint8) + w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales) + + w2_weight_scale = layer.w2_weight_scale.data + w2_scale = w2_weight_scale.to(torch.uint8) + w2_scale_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scale) + + w13_weight = w13_weight_swapped + w13_bias = w13_bias_swapped + w13_weight_scale = w31_scales_interleaved + w2_weight_scale = w2_scale_interleaved + + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + elif mxfp4_backend in ( + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.TRITON_UNFUSED, + ): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, + layer.w13_weight_scale, + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, + layer.w2_weight_scale, + ) + + w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) + w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) + + del layer.w13_weight + del layer.w2_weight + + return ( + w13_weight, + w2_weight, + w13_precision_config, + w2_precision_config, + w13_bias, + w2_bias, + ) + else: + raise ValueError( + f"Unsupported mxfp4_backend: {mxfp4_backend}: " + f"should be one of: {list(Mxfp4MoeBackend)}." + ) + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + +def mxfp4_round_up_hidden_size_and_intermediate_size( + backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int +) -> tuple[int, int]: + if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + # The moe marlin kernel requires that for each linear + # n % 256 == 0 and k % 128 == 0. + # In gate_up_proj: + # n = 2 * intermediate_size_per_partition_after_pad + # k = hidden_size + # In down_proj + # n = hidden_size + # k = intermediate_size_per_partition_after_pad + intermediate_size = round_up(intermediate_size, 128) + if backend == Mxfp4MoeBackend.XPU: + hidden_size = round_up(hidden_size, 128) + else: + hidden_size = round_up(hidden_size, 256) + + elif backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + ): + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + # other padding to increase performance + intermediate_size = round_up(intermediate_size, 256) + hidden_size = round_up(hidden_size, 256) + elif backend in ( + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + ): + intermediate_size = round_up(intermediate_size, 128) + hidden_size = round_up(hidden_size, 128) + elif current_platform.is_rocm(): + pad_align = get_padding_alignment() + intermediate_size = round_up(intermediate_size, pad_align) + hidden_size = round_up(hidden_size, pad_align) + else: + intermediate_size = round_up(intermediate_size, 64) + return hidden_size, intermediate_size + + +def make_mxfp4_moe_quant_config( + mxfp4_backend: Mxfp4MoeBackend, + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig | None: + if mxfp4_backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + ): + return None + if mxfp4_backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + ): + return mxfp4_mxfp8_moe_quant_config( + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + elif mxfp4_backend in ( + Mxfp4MoeBackend.MARLIN, + Mxfp4MoeBackend.BATCHED_MARLIN, + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, + ): + return mxfp4_w4a16_moe_quant_config( + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + else: + return ocp_mx_moe_quant_config( + quant_dtype="mxfp4", + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + +def make_mxfp4_moe_kernel( + moe_quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + mxfp4_backend: Mxfp4MoeBackend, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + shared_experts: torch.nn.Module | None = None, +): + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=routing_tables, + allow_new_interface=True, + ) + assert prepare_finalize is not None + + logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local") + + # Create Experts. + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + max_num_tokens = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens is not None + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + ) + else: + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + ) + + # NOTE(rob): we only want the mk to control the shared_expert + # if using all2all (for SBO). bnell is making this explict in + # the new MoE runner class. + kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=( + shared_experts + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), + moe_parallel_config=moe_config.moe_parallel_config, + inplace=( + not moe_config.disable_inplace + and mxfp4_backend + not in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, + ) + ), + ) + + return kernel diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index ee7db88ccbd0..b19a3903e98d 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - mxfp4_w4a16_moe_quant_config, nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) @@ -383,16 +382,6 @@ def convert_to_nvfp4_moe_kernel_format( ) -def make_mxfp4_moe_quant_config( - w13_scale: torch.Tensor, - w2_scale: torch.Tensor, -) -> FusedMoEQuantConfig: - return mxfp4_w4a16_moe_quant_config( - w1_scale=w13_scale, - w2_scale=w2_scale, - ) - - def make_nvfp4_moe_quant_config( backend: NvFp4MoeBackend, w13_scale: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 2bd4cd79e031..0e2a2974ae15 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -15,7 +16,11 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kMxfp4Static, + kMxfp8Dynamic, ) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -25,7 +30,6 @@ def __init__( self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - max_capture_size, ): super().__init__(moe_config, quant_config) self.device = torch.cuda.current_device() @@ -39,7 +43,9 @@ def __init__( self.gemm1_clamp_limit = torch.tensor( [7.0] * self.num_experts, dtype=torch.float32, device=self.device ) - self.max_capture_size = max_capture_size + self.max_capture_size = ( + get_current_vllm_config().compilation_config.max_cudagraph_capture_size + ) @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -47,41 +53,35 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - raise NotImplementedError( - "TrtLlmGenExperts is not yet used by an Oracle. " - "This method should not be called." + p = current_platform + return ( + p.is_cuda() + and (p.is_device_capability_family(100)) + and has_flashinfer_trtllm_fused_moe() ) @staticmethod def _supports_no_act_and_mul() -> bool: - raise NotImplementedError( - "TrtLlmGenExperts is not yet used by an Oracle. " - "This method should not be called." - ) + return False @staticmethod def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - raise NotImplementedError( - "TrtLlmGenExperts is not yet used by an Oracle. " - "This method should not be called." - ) + SUPPORTED_W_A = [ + (kMxfp4Static, None), + (kMxfp4Static, kMxfp8Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: MoEActivation) -> bool: - raise NotImplementedError( - "TrtLlmGenExperts is not yet used by an Oracle. " - "This method should not be called." - ) + return activation == MoEActivation.SWIGLUOAI @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - raise NotImplementedError( - "TrtLlmGenExperts is not yet used by an Oracle. " - "This method should not be called." - ) + return True def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 097d0bc01891..3d450e1399aa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -46,11 +46,15 @@ make_fp8_moe_quant_config, select_fp8_moe_backend, ) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + make_mxfp4_moe_kernel, + make_mxfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, - make_mxfp4_moe_quant_config, make_nvfp4_moe_kernel, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, @@ -83,7 +87,7 @@ marlin_moe_permute_scales, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin, + prepare_moe_mxfp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, @@ -243,7 +247,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): def __init__(self, moe): super().__init__(moe) self.group_size = 32 - self.mxfp4_backend = NvFp4MoeBackend.MARLIN + self.mxfp4_backend = Mxfp4MoeBackend.MARLIN self.experts_cls = MarlinExperts def create_weights( @@ -318,7 +322,9 @@ def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: return make_mxfp4_moe_quant_config( - w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale + mxfp4_backend=self.mxfp4_backend, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, ) def process_weights_after_loading(self, layer: FusedMoE) -> None: @@ -332,13 +338,31 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: ) delattr(layer, "w2_weight_packed") - prepare_moe_fp4_layer_for_marlin(layer) + ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + _, + _, + ) = prepare_moe_mxfp4_layer_for_marlin( + layer, + w13=layer.w13_weight, + w2=layer.w2_weight, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w13_weight_scale", w13_weight_scale) + replace_parameter(layer, "w2_weight_scale", w2_weight_scale) self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None: - self.moe_mk = make_nvfp4_moe_kernel( + self.moe_mk = make_mxfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, + mxfp4_backend=self.mxfp4_backend, experts_cls=self.experts_cls, shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(), diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index d81f0f80d2e7..a2bd95b4244c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum import torch -from torch.nn.parameter import Parameter -from vllm import envs from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -16,164 +13,33 @@ MoEActivation, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk -from vllm.model_executor.layers.fused_moe.all2all_utils import ( - maybe_make_prepare_finalize, -) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, - mxfp4_mxfp8_moe_quant_config, - mxfp4_w4a16_moe_quant_config, - ocp_mx_moe_quant_config, -) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - BatchedMarlinExperts, - MarlinExperts, ) -from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - OAITritonExperts, - UnfusedOAITritonExperts, +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + convert_to_mxfp4_moe_kernel_format, + make_mxfp4_moe_kernel, + make_mxfp4_moe_quant_config, + mxfp4_round_up_hidden_size_and_intermediate_size, + select_mxfp4_moe_backend, ) -from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - get_marlin_input_dtype, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin, -) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( _can_support_mxfp4, - _swizzle_mxfp4, - get_padding_alignment, ) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer -from vllm.utils.import_utils import has_triton_kernels -from vllm.utils.math_utils import round_up logger = init_logger(__name__) -# enum for mxfp4 backend -class Mxfp4Backend(Enum): - NONE = 0 - - # FlashInfer Backend - SM100_FI_MXFP4_MXFP8_TRTLLM = 1 - SM100_FI_MXFP4_MXFP8_CUTLASS = 2 - SM100_FI_MXFP4_BF16 = 3 - SM90_FI_MXFP4_BF16 = 4 - - # Marlin Backend - MARLIN = 5 - - # Triton Backend - TRITON = 6 - - -def get_mxfp4_backend_with_lora() -> Mxfp4Backend: - """ - Not all MXFP4 backends support LoRA. Select backends that are known to - have LoRA support. - """ - if not current_platform.is_cuda(): - return Mxfp4Backend.NONE - - # If FlashInfer is not available, try either Marlin or Triton - triton_kernels_supported = ( - has_triton_kernels() - # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 - # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 - # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 - and (9, 0) <= current_platform.get_device_capability() < (11, 0) - ) - if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: - logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend") - return Mxfp4Backend.TRITON - - logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend") - return Mxfp4Backend.MARLIN - - -def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: - # Backend Selection - - if with_lora_support: - return get_mxfp4_backend_with_lora() - - if current_platform.is_cuda(): - if ( - current_platform.is_device_capability(90) - and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 - ): - logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") - return Mxfp4Backend.SM90_FI_MXFP4_BF16 - elif ( - current_platform.is_device_capability_family(100) - and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS - ): - logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") - return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - elif ( - current_platform.is_device_capability_family(100) - and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - ): - logger.info_once( - "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local" - ) - return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - elif current_platform.is_device_capability_family(100) and has_flashinfer(): - logger.info_once( - "Using FlashInfer MXFP4 BF16 backend for SM100, " - "For faster performance on SM100, consider setting " - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " - "accuracy." - ) - return Mxfp4Backend.SM100_FI_MXFP4_BF16 - elif ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(90) - ) and not has_flashinfer(): - logger.warning_once( - "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " - "is not available. This may result in degraded performance. " - "Please `pip install vllm[flashinfer]` for best results." - ) - - # If FlashInfer is not available, try either Marlin or Triton - triton_kernels_supported = ( - has_triton_kernels() - # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 - # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 - # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 - and (9, 0) <= current_platform.get_device_capability() < (11, 0) - ) - if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported: - logger.info_once("Using Marlin backend") - return Mxfp4Backend.MARLIN - else: - logger.info_once("Using Triton backend") - return Mxfp4Backend.TRITON - elif current_platform.is_xpu(): - logger.info_once("Using xpu backend on XPU") - return Mxfp4Backend.MARLIN - elif current_platform.is_rocm() and has_triton_kernels(): - logger.info_once("Using Triton backend") - return Mxfp4Backend.TRITON - - return Mxfp4Backend.NONE - - class Mxfp4Config(QuantizationConfig): def __init__(self, ignored_layers: list[str] | None = None): super().__init__() @@ -244,21 +110,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.weight_dtype = "mxfp4" - self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) - assert self.mxfp4_backend != Mxfp4Backend.NONE, ( - f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found" - "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)." - "Please check your environment and try again." - ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} # Initialized in process_weights_after_loading for CUTLASS/SM90 backends self.moe_mk: mk.FusedMoEModularKernel | None = None + # we round up the parameter here to ensure a2a kernel allocate the correct memory size # noqa: E501 + self.moe.hidden_dim, self.moe.intermediate_size_per_partition = ( + mxfp4_round_up_hidden_size_and_intermediate_size( + self.mxfp4_backend, + self.moe.hidden_dim, + self.moe.intermediate_size_per_partition, + ) + ) + + # used for triton kernel + self.w13_precision_config = None + self.w2_precision_config = None + def create_weights( self, layer: torch.nn.Module, @@ -282,62 +156,12 @@ def create_weights( mxfp4_block = 32 - intermediate_size_per_partition_after_pad = intermediate_size_per_partition - if self.mxfp4_backend == Mxfp4Backend.MARLIN: - # The moe marlin kernel requires that for each linear - # n % 256 == 0 and k % 128 == 0. - # In gate_up_proj: - # n = 2 * intermediate_size_per_partition_after_pad - # k = hidden_size - # In down_proj - # n = hidden_size - # k = intermediate_size_per_partition_after_pad - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128 - ) - if current_platform.is_xpu(): - hidden_size = round_up(hidden_size, 128) - else: - hidden_size = round_up(hidden_size, 256) - - layer.params_dtype = params_dtype - layer.num_experts = num_experts - layer.hidden_size = hidden_size - layer.intermediate_size_per_partition = ( - intermediate_size_per_partition_after_pad - ) - elif ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 - ): - # pad the intermediate size to be a multiple of 2 * mxfp4_block - # for to hold non-uniform sharded tensor as well as swizzling - # other padding to increase performance - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 256 - ) - hidden_size = round_up(hidden_size, 256) - elif ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - ): - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128 - ) - hidden_size = round_up(hidden_size, 128) - elif current_platform.is_rocm(): - pad_align = get_padding_alignment() - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, pad_align - ) - hidden_size = round_up(hidden_size, pad_align) - else: - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 64 - ) + # Directly use padded size from config + self.intermediate_size = intermediate_size_per_partition_after_pad = ( + self.moe.intermediate_size_per_partition + ) + self.hidden_size = hidden_size = self.moe.hidden_dim - self.intermediate_size = intermediate_size_per_partition_after_pad - self.hidden_size = hidden_size # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( @@ -363,17 +187,6 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w13_bias = torch.nn.Parameter( - torch.zeros( - num_experts, - 2 * intermediate_size_per_partition_after_pad, - dtype=torch.bfloat16, - ), - requires_grad=False, - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - # down_proj (row parallel) w2_weight = torch.nn.Parameter( torch.zeros( @@ -399,540 +212,193 @@ def create_weights( layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - w2_bias = torch.nn.Parameter( - torch.zeros( - num_experts, - hidden_size, - dtype=torch.bfloat16, - ), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - def process_weights_after_loading(self, layer): - if self.mxfp4_backend == Mxfp4Backend.MARLIN: - prepare_moe_fp4_layer_for_marlin( - layer, input_dtype=get_marlin_input_dtype() - ) - - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - assert self.moe_quant_config is not None - - prepare_finalize = maybe_make_prepare_finalize( - moe=self.moe, - quant_config=self.moe_quant_config, - routing_tables=layer._maybe_init_expert_routing_tables(), - allow_new_interface=True, - ) - assert prepare_finalize is not None - - self.moe_mk = mk.FusedMoEModularKernel( - prepare_finalize, - MarlinExperts( - self.moe, - self.moe_quant_config, + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, ), - inplace=not self.moe.disable_inplace, - shared_experts=None, - ) - elif ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 - ): - from flashinfer.fp4_quantization import nvfp4_block_scale_interleave - from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache - - layer.gemm1_alpha = Parameter( - torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False, - ) - layer.gemm1_beta = Parameter( - torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False, - ) - layer.gemm1_clamp_limit = Parameter( - torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False, ) - sf_block_size = 32 # mxfp4 block size - - assert ( - layer.w13_weight.dim() == 3 - and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 - and layer.w13_weight.shape[2] == self.hidden_size // 2 - ) - assert ( - layer.w13_weight_scale.dim() == 3 - and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 - and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size - ) - assert ( - layer.w2_weight.dim() == 3 - and layer.w2_weight.shape[0] == self.num_experts - and layer.w2_weight.shape[1] == self.hidden_size - and layer.w2_weight.shape[2] == self.intermediate_size // 2 - ) - assert ( - layer.w2_weight_scale.dim() == 3 - and layer.w2_weight_scale.shape[1] == self.hidden_size - and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size - ) - assert ( - layer.w13_bias.dim() == 2 - and layer.w13_bias.shape[0] == self.num_experts - and layer.w13_bias.shape[1] == self.intermediate_size * 2 - ) - assert ( - layer.w2_bias.dim() == 2 - and layer.w2_bias.shape[0] == self.num_experts - and layer.w2_bias.shape[1] == self.hidden_size - ) - - w13_weight_scale = layer.w13_weight_scale.data - w2_weight_scale = layer.w2_weight_scale.data - w13_weight = layer.w13_weight.data - w2_weight = layer.w2_weight.data - w13_bias = layer.w13_bias.data.to(torch.float32) - w2_bias = layer.w2_bias.data.to(torch.float32) - - # Swap w1 and w3 as the definition of - # swiglu is different in the trtllm-gen - def swap_every_two_rows(x, axis=-1): - shape = x.shape - if axis < 0: - axis = len(shape) + axis - - # Create a new shape with pairs swapped along specified axis - new_shape = list(shape) - new_shape[axis] = shape[axis] // 2 - new_shape.insert(axis + 1, 2) - - # Reshape to expose pairs, swap them, and reshape back - x = x.reshape(*new_shape) - x = x.flip(axis + 1) - new_shape = list(shape) - return x.reshape(*new_shape) - - w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) - w13_weight = swap_every_two_rows(w13_weight, -2) - w13_bias = swap_every_two_rows(w13_bias, -1) - - # Do not interleave as the checkpoint is already interleaved - - # Shuffle weights and scaling factors for transposed mma output - gemm1_weights_mxfp4_shuffled = [] - gemm1_scales_mxfp4_shuffled = [] - gemm2_weights_mxfp4_shuffled = [] - gemm2_scales_mxfp4_shuffled = [] - gemm1_bias_shuffled = [] - gemm2_bias_shuffled = [] - epilogue_tile_m = 128 # FIXME: this depends on the kernel internals - for i in range(self.num_experts): - # w13 weight shuffling - permute_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w13_weight[i].view(torch.uint8), - epilogue_tile_m, - ) - gemm1_weights_mxfp4_shuffled.append( - w13_weight[i] - .view(torch.uint8)[permute_indices.to(w13_weight.device)] - .contiguous() - ) - # w13 scale shuffling - permute_sf_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m, - num_elts_per_sf=16, - ) - gemm1_scales_mxfp4_shuffled.append( - nvfp4_block_scale_interleave( - w13_weight_scale[i] - .view(torch.uint8)[ - permute_sf_indices.to(w13_weight_scale.device) - ] - .contiguous() - ) - ) - # w13 bias shuffling - permute_bias_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w13_bias[i].clone().reshape(-1, 1), - epilogue_tile_m, - ) - gemm1_bias_shuffled.append( - w13_bias[i] - .clone() - .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] - .contiguous() - ) - # w2 weight shuffling - permute_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w2_weight[i].view(torch.uint8), - epilogue_tile_m, - ) - gemm2_weights_mxfp4_shuffled.append( - w2_weight[i] - .view(torch.uint8)[permute_indices.to(w2_weight.device)] - .contiguous() - ) - # w2 scale shuffling - permute_sf_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m, - num_elts_per_sf=16, - ) - gemm2_scales_mxfp4_shuffled.append( - nvfp4_block_scale_interleave( - w2_weight_scale[i] - .view(torch.uint8)[ - permute_sf_indices.to(w2_weight_scale.device) - ] - .contiguous() - ) - ) - # w2 bias shuffling - permute_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - w2_bias[i].clone().reshape(-1, 1), - epilogue_tile_m, - ) - gemm2_bias_shuffled.append( - w2_bias[i] - .clone() - .reshape(-1, 1)[permute_indices.to(w2_bias.device)] - .contiguous() - ) - - w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) - w13_weight_scale = ( - torch.stack(gemm1_scales_mxfp4_shuffled) - .reshape( - self.num_experts, - 2 * self.intermediate_size, - self.hidden_size // sf_block_size, - ) - .view(torch.float8_e4m3fn) - ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) - w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) - w2_weight_scale = ( - torch.stack(gemm2_scales_mxfp4_shuffled) - .reshape( - self.num_experts, - self.hidden_size, - self.intermediate_size // sf_block_size, - ) - .view(torch.float8_e4m3fn) - ) - - layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) - layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) - layer.w13_bias = Parameter( - torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), - requires_grad=False, - ) - layer.w2_bias = Parameter( - torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1), + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), requires_grad=False, ) - elif ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - ): - sf_block_size = 32 # mxfp4 block size + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) - # Common shape assertions - assert ( - layer.w13_weight.dim() == 3 - and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 - and layer.w13_weight.shape[2] == self.hidden_size // 2 - ) - assert ( - layer.w13_weight_scale.dim() == 3 - and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 - and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size - ) - assert ( - layer.w2_weight.dim() == 3 - and layer.w2_weight.shape[0] == self.num_experts - and layer.w2_weight.shape[1] == self.hidden_size - and layer.w2_weight.shape[2] == self.intermediate_size // 2 - ) - assert ( - layer.w2_weight_scale.dim() == 3 - and layer.w2_weight_scale.shape[1] == self.hidden_size - and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size - ) + def _setup_kernel( + self, + layer: FusedMoE, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + ) -> None: + num_experts = self.num_experts + intermediate_size = self.intermediate_size + hidden_size = self.hidden_size + sf_block_size = 32 # mxfp4 block size + + assert ( + w13.dim() == 3 + and w13.shape[0] == num_experts + and w13.shape[1] == intermediate_size * 2 + and w13.shape[2] == hidden_size // 2 + ) + assert ( + w13_scale.dim() == 3 + and w13_scale.shape[0] == num_experts + and w13_scale.shape[1] == intermediate_size * 2 + and w13_scale.shape[2] == hidden_size // sf_block_size + ) + assert ( + w2.dim() == 3 + and w2.shape[0] == num_experts + and w2.shape[1] == hidden_size + and w2.shape[2] == intermediate_size // 2 + ) + assert ( + w2_scale.dim() == 3 + and w2_scale.shape[1] == hidden_size + and w2_scale.shape[2] == intermediate_size // sf_block_size + ) + if w13_bias is not None: assert ( - layer.w13_bias.dim() == 2 - and layer.w13_bias.shape[0] == self.num_experts - and layer.w13_bias.shape[1] == self.intermediate_size * 2 + w13_bias.dim() == 2 + and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2 ) + if w2_bias is not None: assert ( - layer.w2_bias.dim() == 2 - and layer.w2_bias.shape[0] == self.num_experts - and layer.w2_bias.shape[1] == self.hidden_size + w2_bias.dim() == 2 + and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size + ) + w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( + convert_to_mxfp4_moe_kernel_format( + mxfp4_backend=self.mxfp4_backend, + layer=layer, + w13_weight=w13, + w2_weight=w2, + w13_weight_scale=w13_scale, + w2_weight_scale=w2_scale, + w13_bias=w13_bias, + w2_bias=w2_bias, + _cache_permute_indices=self._cache_permute_indices, ) + ) + # For TRITON backends, weights and scales are wrapped tensors from + # triton_kernels that don't support .detach(). The weights have already + # been deleted and precision_config has been set inside + # convert_to_mxfp4_moe_kernel_format, manually assign parameters + if self.mxfp4_backend not in ( + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.TRITON_UNFUSED, + ): + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + else: + layer.w13_weight = w13 + layer.w2_weight = w2 + self.w13_precision_config = w13_scale + self.w2_precision_config = w2_scale - # De-interleave and swap for w13 weight, bias, and scales - w13_w = layer.w13_weight.data - gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] - deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) - w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) - w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) - - w13_b = layer.w13_bias.data.to(torch.float32) - gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] - deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) - b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) - w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) - - w13_s = layer.w13_weight_scale.data - gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] - deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1) - s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) - w13_scale_swapped = torch.cat([s3, s1], dim=1) - - if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: - from flashinfer import block_scale_interleave - - orig_shape = w13_scale_swapped.shape - w13_scale_interleaved = block_scale_interleave( - w13_scale_swapped.view(torch.uint8) - ).reshape(orig_shape) - - w2_s = layer.w2_weight_scale.data - orig_shape = w2_s.shape - w2_scale_interleaved = block_scale_interleave( - w2_s.view(torch.uint8) - ).reshape(orig_shape) - - layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False) - layer.w13_weight_scale = Parameter( - w13_scale_interleaved, requires_grad=False - ) - layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False) - layer.w2_weight_scale = Parameter( - w2_scale_interleaved, requires_grad=False - ) - elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: - - def _interleave_mxfp4_cutlass_sm90(w): - w_shape = w.shape - w_interleaved = w.reshape( - w_shape[0], w_shape[1], (w_shape[2] // 4), 4 - ) - w_interleaved = w_interleaved.permute(0, 2, 1, 3) - w_interleaved = w_interleaved.reshape( - w_shape[0], w_shape[2] // 4, w_shape[1] * 4 - ) - return w_interleaved - - w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8) - w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales) - - w2_weight_scale = layer.w2_weight_scale.data - w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) - w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales) - - layer.w13_weight = torch.nn.Parameter( - torch.cat([w3_w, w1_w], dim=1), requires_grad=False - ) - layer.w13_bias = torch.nn.Parameter( - w13_bias_swapped, requires_grad=False - ) - layer.w13_weight_scale = torch.nn.Parameter( - w31_scales_interleaved, requires_grad=False - ) - layer.w2_weight_scale = torch.nn.Parameter( - w2_scales_interleaved, requires_grad=False - ) - - # theses two kernels go through the `flashinfer_cutlass_fused_moe` path - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) + if w13_bias is not None and w2_bias is not None: + replace_parameter(layer, "w13_bias", w13_bias) + replace_parameter(layer, "w2_bias", w2_bias) - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - assert self.moe_quant_config is not None - prepare_finalize = maybe_make_prepare_finalize( - moe=self.moe, - quant_config=self.moe_quant_config, + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + # TRITON_MONILITHIC need configs but can't make mk + if ( + self.moe_quant_config + and self.mxfp4_backend != Mxfp4MoeBackend.TRITON_MONOLITHIC + ): + assert self.experts_cls is not None + self.moe_mk = make_mxfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + mxfp4_backend=self.mxfp4_backend, + experts_cls=self.experts_cls, routing_tables=layer._maybe_init_expert_routing_tables(), - allow_new_interface=True, - ) - assert prepare_finalize is not None - - self.moe_mk = mk.FusedMoEModularKernel( - prepare_finalize, - FlashInferExperts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - ), - shared_experts=None, + shared_experts=layer.shared_experts, ) - elif self.mxfp4_backend == Mxfp4Backend.TRITON: - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - w13_bias = layer.w13_bias.to(torch.float32) - w2_bias = layer.w2_bias.to(torch.float32) - - layer.w13_bias = Parameter(w13_bias, requires_grad=False) - layer.w2_bias = Parameter(w2_bias, requires_grad=False) - - # Ideally we'd use FusedMoEModularKernel.prepare_finalize object - # (stored in self.fused_experts) to determine if the MoE has a - # batched activation format. As self.fused_experts is not - # initialized at this point, we resort to checking the MoE config - # directly. - is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels - if is_batched_moe: - num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 - else: - num_warps = 8 - - w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( - layer.w13_weight, layer.w13_weight_scale, num_warps - ) - w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( - layer.w2_weight, layer.w2_weight_scale, num_warps - ) + def process_weights_after_loading(self, layer): + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + w13_bias, + w2_bias, + ) - self.w13_precision_config = PrecisionConfig( - weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) - ) - self.w2_precision_config = PrecisionConfig( - weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) - ) - - self.w13_weight = w13_weight - self.w2_weight = w2_weight - del layer.w13_weight - del layer.w2_weight - layer.w13_weight = w13_weight - layer.w2_weight = w2_weight - else: - raise ValueError( - f"Unsupported mxfp4_backend: {self.mxfp4_backend}: " - f"should be one of: {list(Mxfp4Backend)}." - ) + layer._already_called_process_weights_after_loading = True def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.mxfp4_backend == Mxfp4Backend.MARLIN: - return mxfp4_w4a16_moe_quant_config( - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) - elif self.mxfp4_backend == Mxfp4Backend.TRITON: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w1_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + if self.mxfp4_backend in ( + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.TRITON_MONOLITHIC, + ): + assert self.w13_precision_config is not None + assert self.w2_precision_config is not None + w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config - return mxfp4_w4a16_moe_quant_config( - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) - elif self.mxfp4_backend in [ - Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, - Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS, - ]: - return mxfp4_mxfp8_moe_quant_config( - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) - elif self.mxfp4_backend in [ - Mxfp4Backend.SM100_FI_MXFP4_BF16, - Mxfp4Backend.SM90_FI_MXFP4_BF16, - ]: - return mxfp4_w4a16_moe_quant_config( - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) - else: - w1_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - return ocp_mx_moe_quant_config( - quant_dtype="mxfp4", - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) + + return make_mxfp4_moe_quant_config( + mxfp4_backend=self.mxfp4_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - if ( - prepare_finalize.activation_format - == mk.FusedMoEActivationFormat.BatchedExperts - ): - if self.mxfp4_backend == Mxfp4Backend.MARLIN: - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - assert self.moe_quant_config is not None - return BatchedMarlinExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - moe_config=self.moe, - ) - else: - raise NotImplementedError( - f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for " - "EP batched experts format" - ) - else: - assert self.moe_quant_config is not None - if ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 - ): - # B200 code-path - kwargs = { - # TODO(bnell): part of quant_config - "max_capture_size": self.max_capture_size, - } - return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) - elif self.mxfp4_backend == Mxfp4Backend.MARLIN: - return MarlinExperts(self.moe, self.moe_quant_config) - elif self.mxfp4_backend == Mxfp4Backend.TRITON: - if self.moe.is_lora_enabled: - return UnfusedOAITritonExperts(self.moe, self.moe_quant_config) - return OAITritonExperts(self.moe, self.moe_quant_config) - else: - raise NotImplementedError( - f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" - ) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) @property def is_monolithic(self) -> bool: - return ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 - or self.mxfp4_backend == Mxfp4Backend.TRITON + return self.mxfp4_backend in ( + Mxfp4MoeBackend.TRITON_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, ) def apply( @@ -962,12 +428,6 @@ def apply( layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." - assert ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - or self.mxfp4_backend == Mxfp4Backend.MARLIN - ) - assert self.moe_mk is not None return self.moe_mk( hidden_states=x, @@ -1008,17 +468,23 @@ def apply_monolithic( layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." - if ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + if self.mxfp4_backend in ( + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC, + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, ): from flashinfer import trtllm_fp4_block_scale_moe - if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: + if ( + self.mxfp4_backend + == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC + ): assert x.dtype == torch.bfloat16 x_quant = x x_scale = None - elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: + elif ( + self.mxfp4_backend + == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC + ): from flashinfer import mxfp8_quantize x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 @@ -1054,7 +520,7 @@ def apply_monolithic( tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output - elif self.mxfp4_backend == Mxfp4Backend.TRITON: + elif self.mxfp4_backend == Mxfp4MoeBackend.TRITON_MONOLITHIC: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward, ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8394857cf9c2..f96d046e7e21 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -25,9 +25,9 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Backend, - get_mxfp4_backend, +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + select_mxfp4_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, @@ -674,9 +674,9 @@ def __init__( f"Please check that the combination is supported in OCP_MX_Scheme." ) - self.mxfp4_backend: Mxfp4Backend | None = None + self.mxfp4_backend: Mxfp4MoeBackend | None = None if self.ocp_mx_scheme == "w_mxfp4": - self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + self.mxfp4_backend, experts_cls = select_mxfp4_moe_backend(moe) if self.input_quant is not None: self.static_input_scales = not self.input_quant.get("is_dynamic") @@ -991,9 +991,10 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if not self.emulate: - if ( - self.model_type == "gpt_oss" - and self.mxfp4_backend == Mxfp4Backend.TRITON + if self.model_type == "gpt_oss" and self.mxfp4_backend in ( + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.TRITON_MONOLITHIC, ): raise NotImplementedError( "Triton kernel implemented fused MoE for GPT_OSS model " diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 41d5293938fd..2a85aca9e82c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -332,28 +332,42 @@ def premute_scales( return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2 -def prepare_moe_fp4_layer_for_marlin( - layer: torch.nn.Module, input_dtype: torch.dtype | None = None -) -> None: +def prepare_moe_mxfp4_layer_for_marlin( + layer: torch.nn.Module, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, +]: logger.warning_once( "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. Weight-only FP4 compression will " "be used leveraging the Marlin kernel. This may degrade " "performance for compute-heavy workloads." ) + input_dtype = get_marlin_input_dtype(prefix="") - is_nvfp4 = hasattr(layer, "w13_weight_scale_2") - if input_dtype is not None and input_dtype.itemsize == 1: - if is_nvfp4: - raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.") - elif input_dtype != torch.float8_e4m3fn: - raise RuntimeError("MXFP4 weight + INT8 activation is not supported.") + if ( + input_dtype is not None + and input_dtype.itemsize == 1 + and input_dtype != torch.float8_e4m3fn + ): + raise RuntimeError("MXFP4 weight + INT8 activation is not supported.") - group_size = 16 if is_nvfp4 else 32 + group_size = 32 - e = layer.num_experts - k = layer.hidden_size - n = layer.intermediate_size_per_partition + e = layer.moe_config.num_experts + k = layer.moe_config.hidden_size + n = layer.moe_config.intermediate_size_per_partition # WORKSPACE device = layer.w13_weight.device @@ -364,8 +378,8 @@ def prepare_moe_fp4_layer_for_marlin( # WEIGHT # Repack weights to marlin format - for name in ["w13_weight", "w2_weight"]: - weight = getattr(layer, name) + + def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor: tensor_list = [] if "w13" in name: size_n, size_k = n * 2, k @@ -388,19 +402,17 @@ def prepare_moe_fp4_layer_for_marlin( tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - weight = torch.nn.Parameter(weight, requires_grad=False) - setattr(layer, name, weight) + return weight + + w13 = repack_weight(w13, "w13") + w2 = repack_weight(w2, "w2") # WEIGHT SCALES # Permute scales - for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale") - if not is_nvfp4: - scales = scales.view(torch.float8_e8m0fnu) + def premute_scales(scales: torch.Tensor, name: str) -> torch.Tensor: + scales = scales.view(torch.float8_e8m0fnu) scales = scales.to(param_dtype) - if is_nvfp4: - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -418,30 +430,18 @@ def prepare_moe_fp4_layer_for_marlin( group_size=group_size, is_a_8bit=is_a_8bit, ) - if is_nvfp4: - marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - else: - marlin_scales = mxfp4_marlin_process_scales( - marlin_scales, input_dtype=input_dtype - ) + + marlin_scales = mxfp4_marlin_process_scales( + marlin_scales, input_dtype=input_dtype + ) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - scales = torch.nn.Parameter(scales, requires_grad=False) - setattr(layer, name + "_weight_scale", scales) - if is_nvfp4: - global_scale = nvfp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) - - # BIAS - # Permute bias - for name in ["w13_bias", "w2_bias"]: - if not hasattr(layer, name): - continue - bias = getattr(layer, name).to(param_dtype) + w13_scale = premute_scales(w13_scale, "w13") + w2_scale = premute_scales(w2_scale, "w2") + def premute_bias(bias: torch.Tensor) -> torch.Tensor: tensor_list = [] for i in range(e): expert_bias = bias[i] @@ -449,8 +449,12 @@ def prepare_moe_fp4_layer_for_marlin( tensor_list.append(marlin_permute_bias(expert_bias)) bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - bias = torch.nn.Parameter(bias, requires_grad=False) - setattr(layer, name, bias) + + if w13_bias is not None and w2_bias is not None: + w13_bias = premute_bias(w13_bias) + w2_bias = premute_bias(w2_bias) + + return w13, w2, w13_scale, w2_scale, w13_bias, w2_bias def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None): diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9dbfc6ecad7b..318071b0300e 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -15,7 +15,7 @@ logger = init_logger(__name__) -def _swizzle_mxfp4(quant_tensor, scale, num_warps): +def _swizzle_mxfp4(quant_tensor, scale, num_warps=8): """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" assert has_triton_kernels() import triton_kernels.matmul_ogs_details.opt_flags as opt_flags