Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9a4bb0b
Add AMD AITER MLA fusion optimization for DeepSeek models
khairulkabir1661 Feb 27, 2026
dda6084
Add comprehensive tests for MLA fusion on AMD AITER
khairulkabir1661 Mar 2, 2026
1eafc57
Fix pre-commit issues in test_mla_fusion.py
khairulkabir1661 Mar 2, 2026
b020a60
Fix pytest mark warnings in test_mla_fusion.py
khairulkabir1661 Mar 2, 2026
cb71c3b
test: Remove placeholder tests from MLA fusion test suite
khairulkabir1661 Mar 3, 2026
2de8c58
Fix code review issues: improve exception handling and add logging
khairulkabir1661 Mar 4, 2026
d95dd82
Fix MLA fusion: use custom op pattern and clean up tests
khairulkabir1661 Mar 5, 2026
dc47e37
Fix MLA fusion tests: compare FP8-fused vs FP8-unfused
khairulkabir1661 Mar 5, 2026
882debd
Remove test_deterministic_outputs from MLA fusion tests
khairulkabir1661 Mar 5, 2026
d132b95
Fix MLA fusion custom op registration and optimize tests
khairulkabir1661 Mar 5, 2026
59895ac
[ROCm][FP8] Add x_scale parameter support for MLA fusion (Option 2)
khairulkabir1661 Mar 6, 2026
8c38d23
Fix mypy signature compatibility for x_scale/input_scale parameters
khairulkabir1661 Mar 6, 2026
576b09e
Clean up mla.py comments (lines 259-287)
khairulkabir1661 Mar 27, 2026
6ff38d8
Clarify q_c_scale comment in mla.py (line 240)
khairulkabir1661 Mar 27, 2026
16bea47
Clean up fusion init comments (lines 213-231)
khairulkabir1661 Mar 27, 2026
f49af5e
Clean up AITER fusion helper comments (lines 15-112)
khairulkabir1661 Mar 27, 2026
5a505ea
Remove test_mla_fusion.py test file
khairulkabir1661 Mar 27, 2026
b54784a
Clean up comments in fp8_utils.py
khairulkabir1661 Mar 27, 2026
709d0ed
Fix input_scale and output_dtype handling in FP8 quantization
khairulkabir1661 Mar 27, 2026
7e677a7
Remove is_layer_moe_router_gate check from batch invariance
khairulkabir1661 Mar 27, 2026
6f34c3b
Remove @torch_compile_guard to make fusion transparent to compiler
khairulkabir1661 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None:
self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16

def forward(
self, x: torch.Tensor
self, x: torch.Tensor, x_scale: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
Expand Down
21 changes: 16 additions & 5 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert input_scale is None, (
"UnquantizedLinearMethod does not support input_scale"
)
if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
Expand Down Expand Up @@ -384,11 +388,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
def forward(
self,
x: torch.Tensor,
x_scale: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None

output = self.quant_method.apply(self, x, bias)
output = self.quant_method.apply(self, x, bias, input_scale=x_scale)

if not self.return_bias:
return output
Expand Down Expand Up @@ -574,12 +579,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
def forward(
self,
input_,
x_scale: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
output_parallel = self.quant_method.apply(
self, input_, bias, input_scale=x_scale
)

if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
Expand Down Expand Up @@ -1512,6 +1520,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
def forward(
self,
input_,
x_scale: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
Expand All @@ -1523,10 +1532,12 @@ def forward(

# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
# Only fuse bias add into GEMM for rank 0 (ensures bias not
# added multiple times in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
output_parallel = self.quant_method.apply(
self, input_parallel, bias_, input_scale=x_scale
)

if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
Expand Down
136 changes: 131 additions & 5 deletions vllm/model_executor/layers/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,93 @@
import torch

from vllm.config import CacheConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.quantization import QuantizationConfig

logger = init_logger(__name__)

# Import AITER ops for fused RMSNorm + FP8 quantization
try:
from aiter import dtypes
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

_AITER_AVAILABLE = True
except ImportError:
_AITER_AVAILABLE = False
dtypes = None
torch_compile_guard = None
fused_rms_fp8_group_quant = None


def _fused_rms_fp8_group_quant_fake(
q_c: torch.Tensor,
q_a_layernorm_weight: torch.Tensor,
q_a_layernorm_variance_epsilon: float,
kv_c: torch.Tensor,
kv_a_layernorm_weight: torch.Tensor,
kv_a_layernorm_variance_epsilon: float,
dtype_quant: torch.dtype | None = None,
group_size: int = 128,
output_unquantized_inp1: bool = False,
transpose_scale: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile/CUDA graphs."""
if dtype_quant is None:
dtype_quant = dtypes.fp8
m, n1 = q_c.shape
out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=q_c.device)
out1_bs = torch.empty(
(m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=q_c.device
)
if transpose_scale:
out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape)
out2 = torch.empty_like(kv_c)
return out1_quantized, out1_bs, out2


def _fuse_rmsnorm_quant_impl(
    q_c: torch.Tensor,
    q_a_layernorm_weight: torch.Tensor,
    q_a_layernorm_variance_epsilon: float,
    kv_c: torch.Tensor,
    kv_a_layernorm_weight: torch.Tensor,
    kv_a_layernorm_variance_epsilon: float,
    dtype_quant: torch.dtype | None = None,
    group_size: int = 128,
    output_unquantized_inp1: bool = False,
    transpose_scale: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused dual RMSNorm + FP8 quantization via the AITER Triton kernel.

    Applies RMSNorm to ``q_c`` followed by FP8 group quantization, and
    RMSNorm to ``kv_c`` without quantization, in a single kernel launch.

    Returns:
        (q_c_quantized, q_c_scale, kv_c_normed)
    """
    fused_out = fused_rms_fp8_group_quant(
        q_c,
        q_a_layernorm_weight,
        q_a_layernorm_variance_epsilon,
        kv_c,
        kv_a_layernorm_weight,
        kv_a_layernorm_variance_epsilon,
        group_size,
        dtype_quant,
        None,  # no caller-provided scale tensor
        output_unquantized_inp1,
        transpose_scale,
    )
    # The kernel returns ((quantized, scale), inp1_unquantized, normed2, _);
    # only the three fused outputs are surfaced to callers.
    (q_c_quantized, q_c_scale), _, kv_c_normed, _ = fused_out
    return q_c_quantized, q_c_scale, kv_c_normed


# Intentionally NOT wrapped in @torch_compile_guard: leaving the helper
# transparent lets the compiler trace through and batch operations efficiently.
_fuse_rmsnorm_quant = _fuse_rmsnorm_quant_impl


@dataclass
class MLAModules:
Expand Down Expand Up @@ -110,6 +193,23 @@ def __init__(

self.prefix = prefix

# Enable RMSNorm+Quant fusion when AITER is available with FP8
self.quant_config = quant_config
self.quant_dtype = None
self.fuse_qknorm_quant = False

if _AITER_AVAILABLE and quant_config is not None:
from vllm.model_executor.layers.quantization.fp8 import Fp8Config

if isinstance(quant_config, Fp8Config):
self.quant_dtype = dtypes.fp8
self.fuse_qknorm_quant = True
logger.info(
"[MLA_FUSION_INIT] Fusion enabled for %s: "
"AITER available and FP8 quantization detected",
prefix,
)

def forward(
self,
positions: torch.Tensor,
Expand All @@ -118,6 +218,7 @@ def forward(
) -> torch.Tensor:
q_c = None
kv_lora = None
q_c_scale = None # Set when fuse_qknorm_quant is enabled

if self.q_lora_rank is not None:
assert self.fused_qkv_a_proj is not None, (
Expand All @@ -130,13 +231,37 @@ def forward(
"q_b_proj is required when q_lora_rank is not None"
)

# Step 1: QKV projection (use existing layer)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
kv_c, k_pe = kv_lora.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)

# Step 2: Apply RMSNorm and optional FP8 quantization
if self.fuse_qknorm_quant:
# Fused RMSNorm + FP8 quantization
q_c_quantized, q_c_scale, kv_c_normed = _fuse_rmsnorm_quant(
q_c,
self.q_a_layernorm.weight,
self.q_a_layernorm.variance_epsilon,
kv_c,
self.kv_a_layernorm.weight,
self.kv_a_layernorm.variance_epsilon,
dtype_quant=self.quant_dtype,
group_size=128,
output_unquantized_inp1=False,
transpose_scale=True,
)
q = self.q_b_proj(q_c_quantized, x_scale=q_c_scale)[0]
else:
# Unfused path: RMSNorm only
q_c = self.q_a_layernorm(q_c)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = self.q_b_proj(q_c)[0]
else:
assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None"
Expand All @@ -146,9 +271,10 @@ def forward(
)
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]

kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
kv_c, k_pe = kv_lora.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
kv_c_normed = self.kv_a_layernorm(kv_c)

q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
Expand Down
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
Expand All @@ -451,7 +452,9 @@ def apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
input_scale=input_scale
if input_scale is not None
else layer.input_scale,
bias=bias,
)
else:
Expand Down Expand Up @@ -488,7 +491,7 @@ def apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
input_scale=input_scale,
bias=bias,
)

Expand Down
Loading
Loading