diff --git a/vllm/envs.py b/vllm/envs.py index ec8d663141a6..13830e640636 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -118,7 +118,7 @@ VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False - VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True @@ -1016,9 +1016,9 @@ def _get_or_set_default() -> str: in ("true", "1") ), # Whether to use aiter fusion shared experts ops. - # By default is disabled. + # Enabled by default for better MoE performance (matching ATOM defaults). "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() in ("true", "1") ), # Whether to use aiter triton kernels for gemm ops. 
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index c69e99a68126..29343ae4e294 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -3,6 +3,7 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -25,13 +26,25 @@ mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.model_executor.utils import replace_parameter, set_weight_attrs logger = init_logger(__name__) @@ -72,9 +85,21 @@ def get_quant_method( fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled(): + logger.info_once( + "Using AITER MXFP4 linear method on ROCm.", + scope="local", + ) + return Mxfp4LinearMethod() + if current_platform.is_cuda(): + logger.info_once( + "Using Marlin MXFP4 linear method on CUDA.", + scope="local", + ) + return Mxfp4LinearMethod() logger.debug_once( - "MXFP4 linear layer is not implemented - falling back to " - "UnquantizedLinearMethod.", + "MXFP4 linear layer is not supported on this 
class Mxfp4LinearMethod(LinearMethodBase):
    """MXFP4 quantized linear method.

    On ROCm (with AITER enabled): uses AITER's Triton FP4 GEMM through
    ``rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`` with dynamic activation
    quantization.
    On every other platform that reaches this class (CUDA): uses the Marlin
    FP4 kernel.

    Weights are stored packed, two FP4 values per ``uint8`` byte, with one
    ``uint8`` scale per ``MXFP4_BLOCK_SIZE`` input elements.
    """

    # Number of consecutive input-dim elements that share one scale value.
    MXFP4_BLOCK_SIZE = 32

    @staticmethod
    def _use_aiter() -> bool:
        """Return True when the ROCm AITER kernel path should be used."""
        return current_platform.is_rocm() and rocm_aiter_ops.is_enabled()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed FP4 weight and its block scales on ``layer``.

        Args:
            layer: Module that receives the ``weight`` and ``weight_scale``
                parameters plus the partition-size attributes.
            input_size_per_partition: Size of the input (K) dim on this rank.
            output_partition_sizes: Output (N) sizes of the logical
                sub-matrices fused into this layer.
            input_size: Unpartitioned input size (unused here).
            output_size: Unpartitioned output size (unused here).
            params_dtype: Activation dtype of the model (unused here; the
                quantized tensors are always ``uint8``).
            **extra_weight_attrs: May carry ``weight_loader``, which is
                forwarded to both parameters.

        Raises:
            ValueError: If ``input_size_per_partition`` is not a multiple of
                ``MXFP4_BLOCK_SIZE``; the scale tensor would otherwise be
                silently truncated by the integer division below.
        """
        block_size = self.MXFP4_BLOCK_SIZE
        if input_size_per_partition % block_size != 0:
            raise ValueError(
                f"MXFP4 requires input_size_per_partition "
                f"({input_size_per_partition}) to be a multiple of the "
                f"quantization block size ({block_size})."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Two FP4 values are packed into each byte along the input dim.
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One uint8 scale per block of `block_size` input elements.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // block_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded parameters into the layout the kernel expects.

        The weight is re-wrapped as a plain, non-trainable ``Parameter`` on
        both paths. On the AITER path the scale is transposed; on the Marlin
        path ``prepare_fp4_layer_for_marlin`` handles the remaining repacking.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        if self._use_aiter():
            # Transpose scale so that triton_fp4_gemm_dynamic_qaunt's
            # internal .T produces the [N, K/32] layout the kernel expects.
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data.T.contiguous(), requires_grad=False
            )
        else:
            prepare_fp4_layer_for_marlin(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the MXFP4 weight of ``layer``.

        Args:
            layer: Module prepared by ``process_weights_after_loading``.
            x: Activations; the last dim is the input (K) dim.
            bias: Optional bias added to the result.

        Returns:
            The GEMM result with the same leading dims as ``x``. On the AITER
            path the output dtype is requested as ``x.dtype``.
        """
        if self._use_aiter():
            # Flatten leading dims so the kernel always sees [M, K].
            # NOTE(review): assumes the Marlin helper does the same
            # internally -- confirm.
            x_2d = x.reshape(-1, x.shape[-1])
            # Request the activation dtype instead of a hard-coded bfloat16
            # so float16 models do not silently receive bf16 outputs.
            # NOTE(review): assumes the AITER kernel accepts x.dtype as
            # its output dtype -- confirm for non-bf16 activations.
            out = rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt(
                x_2d, layer.weight, layer.weight_scale, x.dtype
            )
            if bias is not None:
                out = out + bias
            return out.reshape(*x.shape[:-1], out.shape[-1])

        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=None,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )