diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py
index e8ed8a5249d1..b3acc89712cb 100644
--- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py
+++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py
@@ -106,7 +106,7 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None:
             self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
 
     def forward(
-        self, x: torch.Tensor
+        self, x: torch.Tensor, x_scale: torch.Tensor | None = None
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         # Tier 1: DSV3 specialized kernel
         if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 44fd516f5e5c..c36e513463d5 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -222,7 +222,11 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None, (
+            "UnquantizedLinearMethod does not support input_scale"
+        )
         if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
             return linear_batch_invariant(x, layer.weight, bias)
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
@@ -384,11 +388,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
     def forward(
         self,
         x: torch.Tensor,
+        x_scale: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
 
-        output = self.quant_method.apply(self, x, bias)
+        output = self.quant_method.apply(self, x, bias, input_scale=x_scale)
 
         if not self.return_bias:
             return output
@@ -574,12 +579,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
     def forward(
         self,
         input_,
+        x_scale: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self, input_, bias)
+        output_parallel = self.quant_method.apply(
+            self, input_, bias, input_scale=x_scale
+        )
 
         if self.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
@@ -1512,6 +1520,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
     def forward(
         self,
         input_,
+        x_scale: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -1523,10 +1532,12 @@ def forward(
 
         # Matrix multiply.
         assert self.quant_method is not None
-        # Only fuse bias add into GEMM for rank 0 (this ensures that
-        # bias will not get added more than once in TP>1 case)
+        # Only fuse bias add into GEMM for rank 0 (ensures the bias is not
+        # added more than once when TP > 1)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias_)
+        output_parallel = self.quant_method.apply(
+            self, input_parallel, bias_, input_scale=x_scale
+        )
 
         if self.reduce_results and self.tp_size > 1:
             output = tensor_model_parallel_all_reduce(output_parallel)
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 1d3e987b7e17..419ad1382468 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -5,10 +5,93 @@
 import torch
 
 from vllm.config import CacheConfig
+from vllm.logger import init_logger
 from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.attention import MLAAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
 
+logger = init_logger(__name__)
+
+# Optional AITER ops for fused RMSNorm + FP8 quantization; set to None if missing
+try:
+    from aiter import dtypes
+    from aiter.jit.utils.torch_guard import torch_compile_guard
+    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
+
+    _AITER_AVAILABLE = True
+except ImportError:
+    _AITER_AVAILABLE = False
+    dtypes = None
+    torch_compile_guard = None
+    fused_rms_fp8_group_quant = None
+
+
+def _fused_rms_fp8_group_quant_fake(
+    q_c: torch.Tensor,
+    q_a_layernorm_weight: torch.Tensor,
+    q_a_layernorm_variance_epsilon: float,
+    kv_c: torch.Tensor,
+    kv_a_layernorm_weight: torch.Tensor,
+    kv_a_layernorm_variance_epsilon: float,
+    dtype_quant: torch.dtype | None = None,
+    group_size: int = 128,
+    output_unquantized_inp1: bool = False,
+    transpose_scale: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fake (allocation-only) implementation for torch.compile/CUDA graphs."""
+    if dtype_quant is None:
+        dtype_quant = dtypes.fp8
+    m, n1 = q_c.shape
+    out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=q_c.device)
+    out1_bs = torch.empty(
+        (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=q_c.device
+    )
+    if transpose_scale:
+        out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape)
+    out2 = torch.empty_like(kv_c)
+    return out1_quantized, out1_bs, out2
+
+
+def _fuse_rmsnorm_quant_impl(
+    q_c: torch.Tensor,
+    q_a_layernorm_weight: torch.Tensor,
+    q_a_layernorm_variance_epsilon: float,
+    kv_c: torch.Tensor,
+    kv_a_layernorm_weight: torch.Tensor,
+    kv_a_layernorm_variance_epsilon: float,
+    dtype_quant: torch.dtype | None = None,
+    group_size: int = 128,
+    output_unquantized_inp1: bool = False,
+    transpose_scale: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fused dual RMSNorm + FP8 quantization using AITER.
+
+    Fuses RMSNorm on q_c with FP8 group quantization, and RMSNorm on kv_c
+    without quantization.
+
+    Returns:
+        (q_c_quantized, q_c_scale, kv_c_normed)
+    """
+    (q_c_quantized, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant(
+        q_c,
+        q_a_layernorm_weight,
+        q_a_layernorm_variance_epsilon,
+        kv_c,
+        kv_a_layernorm_weight,
+        kv_a_layernorm_variance_epsilon,
+        group_size,
+        dtype_quant,
+        None,  # NOTE(review): presumably the optional residual input - confirm
+        output_unquantized_inp1,
+        transpose_scale,
+    )
+    return q_c_quantized, q_c_scale, kv_c_normed
+
+
+# Make fusion transparent to compiler (no @torch_compile_guard)
+# This allows the compiler to trace through and batch operations efficiently
+_fuse_rmsnorm_quant = _fuse_rmsnorm_quant_impl
+
 
 @dataclass
 class MLAModules:
@@ -110,6 +193,23 @@ def __init__(
 
         self.prefix = prefix
 
+        # Enable RMSNorm+Quant fusion when AITER is available with FP8
+        self.quant_config = quant_config
+        self.quant_dtype = None
+        self.fuse_qknorm_quant = False
+
+        if _AITER_AVAILABLE and quant_config is not None:
+            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+
+            if isinstance(quant_config, Fp8Config):
+                self.quant_dtype = dtypes.fp8
+                self.fuse_qknorm_quant = True
+                logger.info(
+                    "[MLA_FUSION_INIT] Fusion enabled for %s: "
+                    "AITER available and FP8 quantization detected",
+                    prefix,
+                )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -118,6 +218,7 @@ def forward(
     ) -> torch.Tensor:
         q_c = None
         kv_lora = None
+        q_c_scale = None  # Set when fuse_qknorm_quant is enabled
 
         if self.q_lora_rank is not None:
             assert self.fused_qkv_a_proj is not None, (
@@ -130,13 +231,37 @@ def forward(
                 "q_b_proj is required when q_lora_rank is not None"
             )
 
+            # Step 1: QKV projection (use existing layer)
             qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
             q_c, kv_lora = qkv_lora.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 dim=-1,
             )
-            q_c = self.q_a_layernorm(q_c)
-            q = self.q_b_proj(q_c)[0]
+            kv_c, k_pe = kv_lora.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+
+            # Step 2: Apply RMSNorm and optional FP8 quantization
+            if self.fuse_qknorm_quant:
+                # Fused RMSNorm + FP8 quant (q_c is not rebound to a normed tensor)
+                q_c_quantized, q_c_scale, kv_c_normed = _fuse_rmsnorm_quant(
+                    q_c,
+                    self.q_a_layernorm.weight,
+                    self.q_a_layernorm.variance_epsilon,
+                    kv_c,
+                    self.kv_a_layernorm.weight,
+                    self.kv_a_layernorm.variance_epsilon,
+                    dtype_quant=self.quant_dtype,
+                    group_size=128,
+                    output_unquantized_inp1=False,
+                    transpose_scale=True,
+                )
+                q = self.q_b_proj(q_c_quantized, x_scale=q_c_scale)[0]
+            else:
+                # Unfused path: RMSNorm only
+                q_c = self.q_a_layernorm(q_c)
+                kv_c_normed = self.kv_a_layernorm(kv_c)
+                q = self.q_b_proj(q_c)[0]
         else:
             assert self.kv_a_proj_with_mqa is not None, (
                 "kv_a_proj_with_mqa is required when q_lora_rank is None"
@@ -146,9 +271,10 @@ def forward(
             )
             kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
             q = self.q_proj(hidden_states)[0]
-
-        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c)
+            kv_c, k_pe = kv_lora.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_c_normed = self.kv_a_layernorm(kv_c)
 
         q = q.view(-1, self.num_heads, self.qk_head_dim)
         # Add head dim of 1 to k_pe
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 69255a2793cb..9ae347d08feb 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -441,6 +441,7 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
@@ -451,7 +452,9 @@ def apply(
                     input=x,
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
-                    input_scale=layer.input_scale,
+                    input_scale=input_scale
+                    if input_scale is not None
+                    else layer.input_scale,
                     bias=bias,
                 )
             else:
@@ -488,7 +491,7 @@ def apply(
                 input=x,
                 weight=layer.weight,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=input_scale,  # NOTE(review): no layer.input_scale fallback?
                 bias=bias,
             )
 
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 9568d1320bc6..3aa1b77fb1f2 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -399,32 +399,44 @@ def apply(
         weight_scale: torch.Tensor,
         input_scale: torch.Tensor | None = None,
         bias: torch.Tensor | None = None,
+        output_dtype: torch.dtype | None = None,
     ) -> torch.Tensor:
-        assert input_scale is None
         # View input as 2D matrix for fp8 methods
         input_2d = input.view(-1, input.shape[-1])
         output_shape = [*input.shape[:-1], weight.shape[0]]
-        output_dtype = input.dtype
+        # Default output dtype: input.dtype normally, bfloat16 when input_scale is
+        # given (a supplied input_scale is treated as "input is already quantized")
+        if output_dtype is None:
+            output_dtype = input.dtype if input_scale is None else torch.bfloat16
 
         if should_use_flashinfer_for_blockscale_fp8_gemm(
             self.is_flashinfer_supported, output_dtype, input_2d, weight
         ) and should_use_deepgemm_for_fp8_linear(
             output_dtype, weight, self.is_deep_gemm_supported
         ):
+            # FlashInfer: does not support pre-quantized input
+            assert input_scale is None, (
+                "FlashInfer FP8 blockscale GEMM does not support pre-quantized input"
+            )
             output = self._run_flashinfer(input_2d, weight, weight_scale)
 
         elif should_use_deepgemm_for_fp8_linear(
             output_dtype, weight, self.is_deep_gemm_supported
         ):
+            # DeepGEMM: does not support pre-quantized input
+            assert input_scale is None, (
+                "DeepGEMM FP8 linear does not support pre-quantized input"
+            )
             output = self._run_deepgemm(input_2d, weight, weight_scale)
         else:
-            output = self.w8a8_blockscale_op(
-                input_2d, weight, weight_scale, input_scale
+            # AITER/Triton/Cutlass: supports pre-quantized input
+            output = self.w8a8_blockscale_op(  # type: ignore[call-arg]
+                input_2d, weight, weight_scale, input_scale, output_dtype
             )
 
         if bias is not None:
             output = output + bias
-        return output.to(dtype=input.dtype).view(*output_shape)
+        return output.view(*output_shape)
 
     def _run_deepgemm(
         self,
@@ -450,10 +462,19 @@ def _run_cutlass(
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         input_scale: torch.Tensor | None = None,
+        output_dtype: torch.dtype | None = None,
     ) -> torch.Tensor:
-        assert input_scale is None
-        assert self.input_quant_op is not None
-        q_input, input_scale = self.input_quant_op(input_2d)
+        if input_scale is None:
+            # Quantize input if not already quantized
+            assert self.input_quant_op is not None
+            q_input, input_scale = self.input_quant_op(input_2d)
+            if output_dtype is None:
+                output_dtype = input_2d.dtype
+        else:
+            # Use pre-quantized FP8 input directly
+            q_input = input_2d
+            if output_dtype is None:
+                output_dtype = torch.bfloat16
         if self.is_hopper:
             return torch.ops.vllm.padded_cutlass(
                 q_input,
@@ -461,7 +482,7 @@ def _run_cutlass(
                 input_scale,
                 weight_scale,
                 list(self.weight_group_shape),
-                input_2d.dtype,
+                output_dtype,
             )
         else:
             return cutlass_scaled_mm(
@@ -470,7 +491,7 @@ def _run_cutlass(
                 input_scale,
                 weight_scale,
                 list(self.weight_group_shape),
-                input_2d.dtype,
+                output_dtype,
             )
 
     def _run_aiter(
@@ -479,6 +500,7 @@ def _run_aiter(
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         input_scale: torch.Tensor | None = None,
+        output_dtype: torch.dtype | None = None,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
 
@@ -495,9 +517,15 @@ def _run_aiter(
             gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
 
         if input_scale is not None:
+            # Use pre-quantized FP8 input directly
             q_input = input_2d
+            if output_dtype is None:
+                output_dtype = torch.bfloat16
         else:
+            # Quantize input if not already quantized
             q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
+            if output_dtype is None:
+                output_dtype = input_2d.dtype
 
         return gemm_a8w8_blockscale_op(
             q_input,
@@ -505,7 +533,7 @@ def _run_aiter(
             input_scale,
             weight_scale,
             list(self.weight_group_shape),
-            output_dtype=input_2d.dtype,
+            output_dtype=output_dtype,
         )
 
     def _run_triton(
@@ -514,17 +542,26 @@ def _run_triton(
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         input_scale: torch.Tensor | None = None,
+        output_dtype: torch.dtype | None = None,
     ) -> torch.Tensor:
-        assert input_scale is None
-        assert self.input_quant_op is not None
-        q_input, input_scale = self.input_quant_op(input_2d)
+        if input_scale is None:
+            # Quantize input if not already quantized
+            assert self.input_quant_op is not None
+            q_input, input_scale = self.input_quant_op(input_2d)
+            if output_dtype is None:
+                output_dtype = input_2d.dtype
+        else:
+            # Use pre-quantized FP8 input directly
+            q_input = input_2d
+            if output_dtype is None:
+                output_dtype = torch.bfloat16
         return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
             q_input,
             weight,
             input_scale,
             weight_scale,
             list(self.weight_group_shape),
-            input_2d.dtype,
+            output_dtype,
         )
 
     def _run_flashinfer(