Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
# Optional ROCm AITER Triton fusion flags.
# Every flag is bound on BOTH branches so that code reading them later (e.g.
# the decoder layer's forward) never raises NameError when AITER is disabled.
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT

    if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
        # aiter is imported only when this flag is set; the kernel and the
        # FP8 dtype below are therefore defined only in that configuration.
        from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_per_tensor_static_quant
        import aiter as rocm_aiter
        rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
else:
    VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
    # Previously left unbound here, which made any read of this flag on a
    # non-ROCm / non-AITER setup fail with NameError.
    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = False
Expand Down Expand Up @@ -324,18 +332,45 @@ def forward(
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
scale = self.self_attn.qkv_proj.input_scale
if scale is not None and VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
# Static FP8 quantization
weight = self.input_layernorm.weight
eps = self.input_layernorm.variance_epsilon
if residual is None:
residual = hidden_states
hidden_states, _, _, _ = fused_rms_fp8_per_tensor_static_quant(hidden_states, weight, eps, scale,
None, None, eps,
dtype_quant=rocm_aiter_fp8_dtype,
res1=None)
else:
hidden_states, _, _, residual = fused_rms_fp8_per_tensor_static_quant(hidden_states, weight, eps, scale,
None, None, eps,
dtype_quant=rocm_aiter_fp8_dtype,
res1=residual)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
scale = self.mlp.gate_up_proj.input_scale
if scale is not None and VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT:
# Static FP8 quantization
weight = self.post_attention_layernorm.weight
eps = self.post_attention_layernorm.variance_epsilon
hidden_states, _, _, residual = fused_rms_fp8_per_tensor_static_quant(hidden_states, weight, eps, scale,
None, None, eps,
dtype_quant=rocm_aiter_fp8_dtype,
res1=residual)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

Expand Down
Loading