diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 1dc51fcc00ea..900a6aa11c74 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -64,6 +64,14 @@
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
     from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+        from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_per_tensor_static_quant
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+        import aiter as rocm_aiter
+        rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
@@ -324,18 +332,45 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+        scale = self.self_attn.qkv_proj.input_scale
+        if scale is not None and VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+            # Static FP8 quantization
+            weight = self.input_layernorm.weight
+            eps = self.input_layernorm.variance_epsilon
+            if residual is None:
+                residual = hidden_states
+                hidden_states, _, _, _ = fused_rms_fp8_per_tensor_static_quant(
+                    hidden_states, weight, eps, scale, None, None, eps,
+                    dtype_quant=rocm_aiter_fp8_dtype,
+                    res1=None)
+            else:
+                hidden_states, _, _, residual = fused_rms_fp8_per_tensor_static_quant(
+                    hidden_states, weight, eps, scale, None, None, eps,
+                    dtype_quant=rocm_aiter_fp8_dtype,
+                    res1=residual)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
                                        hidden_states=hidden_states)
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        scale = self.mlp.gate_up_proj.input_scale
+        if scale is not None and VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
+            # Static FP8 quantization
+            weight = self.post_attention_layernorm.weight
+            eps = self.post_attention_layernorm.variance_epsilon
+            hidden_states, _, _, residual = fused_rms_fp8_per_tensor_static_quant(
+                hidden_states, weight, eps, scale, None, None, eps,
+                dtype_quant=rocm_aiter_fp8_dtype,
+                res1=residual)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 