From 7018937e83cbb39c5cf9a3d7f47f2f52be11c13c Mon Sep 17 00:00:00 2001 From: Li Date: Sun, 22 Mar 2026 03:45:26 -0700 Subject: [PATCH] [ROCm][Perf] Add MXFP4 linear method and enable shared expert fusion Implement Mxfp4LinearMethod to replace the UnquantizedLinearMethod fallback for MXFP4-quantized linear layers, addressing the TODO in the existing code. On ROCm, this uses the AITER Triton FP4 GEMM (gemm_afp4wfp4) with dynamic activation quantization (matching the ATOM kernel path). On CUDA, it uses the Marlin FP4 kernel. Also enable VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS by default to match ATOM's optimized defaults for MoE model performance. Made-with: Cursor Signed-off-by: Li --- vllm/envs.py | 6 +- .../layers/quantization/mxfp4.py | 122 +++++++++++++++++- 2 files changed, 122 insertions(+), 6 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 2f93b2cb3e0d..aeb91776c5f9 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,7 +111,7 @@ VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False - VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True @@ -982,9 +982,9 @@ def _get_or_set_default() -> str: in ("true", "1") ), # Whether to use aiter fusion shared experts ops. - # By default is disabled. + # Enabled by default for better MoE performance (matching ATOM defaults). "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() in ("true", "1") ), # Whether to use aiter triton kernels for gemm ops.
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 22077be8a44b..514ed92344d2 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -3,6 +3,7 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -25,13 +26,25 @@ mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform @@ -73,9 +86,21 @@ def get_quant_method( fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + logger.info_once( + "Using AITER MXFP4 linear method on ROCm.", + scope="local", + ) + return Mxfp4LinearMethod() + if current_platform.is_cuda(): + logger.info_once( + "Using Marlin MXFP4 linear method on CUDA.", + scope="local", + ) + return Mxfp4LinearMethod() logger.debug_once( - "MXFP4 linear layer is not implemented - falling back to " - "UnquantizedLinearMethod.", + "MXFP4 linear layer is not supported on 
this platform " + "- falling back to UnquantizedLinearMethod.", scope="local", ) return UnquantizedLinearMethod() @@ -97,6 +122,97 @@ def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: return True +class Mxfp4LinearMethod(LinearMethodBase): + """MXFP4 quantized linear method. + + On ROCm: Uses AITER's Triton FP4 GEMM (gemm_afp4wfp4) with dynamic + activation quantization (same kernel path as ATOM); output keeps x.dtype. + On CUDA: Uses the Marlin FP4 kernel. + """ + + MXFP4_BLOCK_SIZE = 32 + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.MXFP4_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + # Transpose scale so that triton_fp4_gemm_dynamic_qaunt's + # internal .T produces the [N, K/32] layout the kernel expects.
+ layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), requires_grad=False + ) + else: + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + prepare_fp4_layer_for_marlin(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + out = rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt( + x, layer.weight, layer.weight_scale, x.dtype + ) + if bias is not None: + out = out + bias + return out + else: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=None, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + class Mxfp4MoEMethod(FusedMoEMethodBase): """MXFP4 MoE quantization method."""