diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..bda115982cc2 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -90,7 +90,7 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 def forward( - self, x: torch.Tensor + self, x: torch.Tensor, x_scale: torch.Tensor | None = None ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: import vllm._custom_ops as ops diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3d0430c315cf..2dd5406f2a71 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -27,6 +27,7 @@ ) from vllm.model_executor.layers.utils import ( dispatch_unquantized_gemm, + is_layer_moe_router_gate, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -222,8 +223,16 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: - if vllm_is_batch_invariant() and current_platform.is_cuda_alike(): + assert input_scale is None, ( + "UnquantizedLinearMethod does not support input_scale" + ) + if ( + vllm_is_batch_invariant() + and current_platform.is_cuda_alike() + and is_layer_moe_router_gate(getattr(layer, "prefix", "")) + ): return linear_batch_invariant(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -384,11 +393,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def forward( self, x: torch.Tensor, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None - output = self.quant_method.apply(self, x, bias) + output = self.quant_method.apply(self, x, bias, input_scale=x_scale) if not self.return_bias: return output @@ -574,12 +584,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) + output_parallel = self.quant_method.apply( + self, input_, bias, input_scale=x_scale + ) if self.gather_output and self.tp_size > 1: # All-gather across the partitions. @@ -1498,6 +1511,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: if self.input_is_parallel: input_parallel = input_ @@ -1509,10 +1523,12 @@ def forward( # Matrix multiply. assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) + # Only fuse bias add into GEMM for rank 0 (ensures bias not + # added multiple times in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, input_parallel, bias_) + output_parallel = self.quant_method.apply( + self, input_parallel, bias_, input_scale=x_scale + ) if self.reduce_results and self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 1d3e987b7e17..c45641827728 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -5,10 +5,97 @@ import torch from vllm.config import CacheConfig +from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.quantization import QuantizationConfig +logger = init_logger(__name__) + +# Import AITER ops for fused RMSNorm + FP8 quantization +try: + from aiter import dtypes + from aiter.jit.utils.torch_guard import torch_compile_guard + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + _AITER_AVAILABLE = True +except ImportError: + _AITER_AVAILABLE = False + dtypes = None + torch_compile_guard = None + fused_rms_fp8_group_quant = None + + +def _fused_rms_fp8_group_quant_fake( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile/CUDA graphs.""" + if dtype_quant is None: + dtype_quant = dtypes.fp8 + m, n1 = q_c.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=q_c.device) + out1_bs = torch.empty( + (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=q_c.device + ) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out2 = torch.empty_like(kv_c) + return out1_quantized, out1_bs, out2 + + +def _fuse_rmsnorm_quant_impl( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm + FP8 quantization using AITER. + + Fuses RMSNorm on q_c with FP8 group quantization, and RMSNorm on kv_c + without quantization. + + Returns: + (q_c_quantized, q_c_scale, kv_c_normed) + """ + (q_c_quantized, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant( + q_c, + q_a_layernorm_weight, + q_a_layernorm_variance_epsilon, + kv_c, + kv_a_layernorm_weight, + kv_a_layernorm_variance_epsilon, + group_size, + dtype_quant, + None, + output_unquantized_inp1, + transpose_scale, + ) + return q_c_quantized, q_c_scale, kv_c_normed + + +# Apply torch_compile_guard decorator when AITER is available +if _AITER_AVAILABLE: + _fuse_rmsnorm_quant = torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake)( + _fuse_rmsnorm_quant_impl + ) +else: + _fuse_rmsnorm_quant = _fuse_rmsnorm_quant_impl + @dataclass class MLAModules: @@ -110,6 +197,23 @@ def __init__( self.prefix = prefix + # Enable RMSNorm+Quant fusion when AITER is available with FP8 + self.quant_config = quant_config + self.quant_dtype = None + self.fuse_qknorm_quant = False + + if _AITER_AVAILABLE and quant_config is not None: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + if isinstance(quant_config, Fp8Config): + self.quant_dtype = dtypes.fp8 + self.fuse_qknorm_quant = True + logger.info( + "[MLA_FUSION_INIT] Fusion enabled for %s: " + "AITER available and FP8 quantization detected", + prefix, + ) + def forward( self, positions: torch.Tensor, @@ -118,6 +222,7 @@ def forward( ) -> torch.Tensor: q_c = None kv_lora = None + q_c_scale = None # Set when fuse_qknorm_quant is enabled if self.q_lora_rank is not None: assert self.fused_qkv_a_proj is not None, ( @@ -130,13 +235,37 @@ def forward( "q_b_proj is required when q_lora_rank is not None" ) + # Step 1: QKV projection (use existing layer) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_lora = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + # Step 2: Apply RMSNorm and optional FP8 quantization + if self.fuse_qknorm_quant: + # Fused RMSNorm + FP8 quantization + q_c_quantized, q_c_scale, kv_c_normed = _fuse_rmsnorm_quant( + q_c, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + kv_c, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + dtype_quant=self.quant_dtype, + group_size=128, + output_unquantized_inp1=False, + transpose_scale=True, + ) + q = self.q_b_proj(q_c_quantized, x_scale=q_c_scale)[0] + else: + # Unfused path: RMSNorm only + q_c = self.q_a_layernorm(q_c) + kv_c_normed = self.kv_a_layernorm(kv_c) + q = self.q_b_proj(q_c)[0] else: assert self.kv_a_proj_with_mqa is not None, ( "kv_a_proj_with_mqa is required when q_lora_rank is None" @@ -146,9 +275,10 @@ def forward( ) kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) # Add head dim of 1 to k_pe diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5101347cd02a..e31ba30b2e4e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -448,6 +448,7 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. @@ -509,7 +510,7 @@ def apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, + input_scale=input_scale, bias=bias, ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 5d7b7b54adc8..e710f322b9bf 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -128,5 +128,8 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: + # Note: PTPC FP8 implementation uses apply_weights which doesn't + # support pre-quantized inputs, so input_scale is ignored. return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 78b1234021af..a740a2c29b44 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -393,7 +393,6 @@ def apply( input_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, ) -> torch.Tensor: - assert input_scale is None # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -404,20 +403,29 @@ def apply( ) and should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # FlashInfer: does not support pre-quantized input + assert input_scale is None, ( + "FlashInfer FP8 blockscale GEMM does not support pre-quantized input" + ) output = self._run_flashinfer(input_2d, weight, weight_scale) elif should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # DeepGEMM: does not support pre-quantized input + assert input_scale is None, ( + "DeepGEMM FP8 linear does not support pre-quantized input" + ) output = self._run_deepgemm(input_2d, weight, weight_scale) else: + # AITER/Triton/Cutlass: supports pre-quantized input output = self.w8a8_blockscale_op( input_2d, weight, weight_scale, input_scale ) if bias is not None: output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + return output.view(*output_shape) def _run_deepgemm( self, @@ -444,9 +452,15 @@ def _run_cutlass( weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + output_dtype = torch.bfloat16 if self.is_hopper: return torch.ops.vllm.padded_cutlass( q_input, @@ -454,7 +468,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) else: return cutlass_scaled_mm( @@ -463,7 +477,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_aiter( @@ -488,9 +502,13 @@ def _run_aiter( gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale if input_scale is not None: + # Use pre-quantized FP8 input directly q_input = input_2d + output_dtype = torch.bfloat16 else: + # Quantize input if not already quantized q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton) + output_dtype = input_2d.dtype return gemm_a8w8_blockscale_op( q_input, @@ -498,7 +516,7 @@ def _run_aiter( input_scale, weight_scale, list(self.weight_group_shape), - output_dtype=input_2d.dtype, + output_dtype=output_dtype, ) def _run_triton( @@ -508,16 +526,22 @@ def _run_triton( weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + output_dtype = torch.bfloat16 return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_flashinfer(