30 changes: 30 additions & 0 deletions vllm/_aiter_ops.py
@@ -836,6 +836,7 @@ class rocm_aiter_ops:
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
@@ -861,6 +862,7 @@ def refresh_env_variables(cls):
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@@ -916,6 +918,11 @@ def is_triton_unified_attn_enabled(cls) -> bool:
def is_fp8bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

@classmethod
@if_aiter_supported
def is_fp4bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED

@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
@@ -1396,6 +1403,29 @@ def triton_rotary_embed(
query = query.view(query_shape)
key = key.view(key_shape)

@staticmethod
def batched_gemm_a16wfp4(
X: torch.Tensor,
W: torch.Tensor,
w_scale: torch.Tensor,
Y: torch.Tensor,
transpose_bm: bool | None = False,
prequant: bool | None = False,
y_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# ruff: noqa: E501 # isort: skip
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

return batched_gemm_a16wfp4(
X,
W,
w_scale,
y=Y,
transpose_bm=transpose_bm,
prequant=prequant,
y_scale=y_scale,
)

@staticmethod
def triton_fp8_bmm(
X: torch.Tensor,
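For orientation, a minimal usage sketch of the new wrapper (not part of the diff). The shapes mirror the fp8 bmm path in mla_attention.py, the meaning of transpose_bm/prequant is inferred from how the MLA code calls them, and it assumes a ROCm build with aiter installed; quark_quantize_weight_to_mxfp4 is the helper added in quark/utils.py further down.

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.quark.utils import (
    quark_quantize_weight_to_mxfp4,
)

# Illustrative sizes from the MLA decode path: N heads, B tokens,
# L = kv_lora_rank, V = v_head_dim.
N, B, L, V = 16, 8, 512, 128

# Quantize a bf16 weight of shape (N, V, L) to packed mxfp4 plus block scales.
w = torch.randn(N, V, L, dtype=torch.bfloat16, device="cuda")
w_fp4, w_scale = quark_quantize_weight_to_mxfp4(w)

x = torch.randn(N, B, L, dtype=torch.bfloat16, device="cuda")    # activations
out = torch.empty(B, N, V, dtype=torch.bfloat16, device="cuda")  # preallocated output

# prequant=True lets the kernel quantize the bf16 activations itself;
# transpose_bm=True writes the (N, B, V) result into `out` as (B, N, V).
rocm_aiter_ops.batched_gemm_a16wfp4(
    x, w_fp4, w_scale, out, transpose_bm=True, prequant=True
)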
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -121,6 +121,7 @@
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
@@ -990,6 +991,11 @@ def get_vllm_port() -> int | None:
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
),
# Whether to use the aiter triton fp4 bmm kernel.
# Enabled by default.
"VLLM_ROCM_USE_AITER_FP4BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
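Since the flag is just an environment variable, the fp4 path can be switched off per process before vLLM reads its env vars; a small sketch (the model choice is illustrative):

import os

# Disable the aiter fp4 bmm kernel; MLA then falls back to the fp8 bmm
# path (if enabled) or to plain bf16 torch.bmm.
os.environ["VLLM_ROCM_USE_AITER_FP4BMM"] = "0"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-R1")  # illustrative model choice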
127 changes: 47 additions & 80 deletions vllm/model_executor/layers/attention/mla_attention.py
@@ -1184,6 +1184,12 @@ def __init__(
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self.is_aiter_triton_fp4_bmm_enabled = (
rocm_aiter_ops.is_fp4bmm_enabled()
and self.kv_b_proj.weight.dtype == torch.bfloat16
)

def process_weights_after_loading(self, act_dtype: torch.dtype):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
@@ -1212,7 +1218,21 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

if self.is_aiter_triton_fp8_bmm_enabled:
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
if self.is_aiter_triton_fp4_bmm_enabled:
from vllm.model_executor.layers.quantization.quark.utils import (
quark_quantize_weight_to_mxfp4,
)

self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
# Convert from (L, N, P) to (N, L, P)
self.W_K = self.W_K.transpose(0, 1)
self.W_K_scale = self.W_K_scale.transpose(0, 1)

self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
W_UV.permute(1, 2, 0)
)
elif self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@@ -1262,16 +1282,26 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

if self.is_aiter_triton_fp8_bmm_enabled:
out = out.view(-1, self.num_heads, self.v_head_dim)
out = out.view(-1, self.num_heads, self.v_head_dim)
if self.is_aiter_triton_fp4_bmm_enabled:
out = rocm_aiter_ops.batched_gemm_a16wfp4(
x,
self.W_V,
self.W_V_scale,
out,
transpose_bm=True,
prequant=True,
y_scale=None,
)
Review comment (medium severity): Base class missing FP4BMM attribute setup in process_weights_after_loading

MLACommonBaseImpl._v_up_proj was modified to use self.W_V and self.W_V_scale when is_aiter_triton_fp4_bmm_enabled is True, but MLACommonBaseImpl.process_weights_after_loading only sets these attributes for FP8BMM (in the if self.is_aiter_triton_fp8_bmm_enabled branch). When FP4BMM is enabled but FP8BMM is disabled, process_weights_after_loading falls into the else branch, which sets W_UV and W_UK_T instead and leaves W_V and W_V_scale undefined. This leaves the base class in an inconsistent state: while MLACommonImpl handles FP4BMM setup correctly, any direct subclass of MLACommonBaseImpl would hit an AttributeError at runtime.

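A minimal sketch of the fix the comment suggests (assuming the right place is MLACommonBaseImpl.process_weights_after_loading, after the W_UK/W_UV split): mirror the fp4 branch that the subclass hunk above already adds, so W_K/W_K_scale and W_V/W_V_scale are defined whenever the fp4 path is enabled.

# Sketch only: mirrors the fp4 branch from MLACommonImpl in the base class.
if self.is_aiter_triton_fp4_bmm_enabled:
    from vllm.model_executor.layers.quantization.quark.utils import (
        quark_quantize_weight_to_mxfp4,
    )

    # W_UK: (L, N, P) -> quantize, then transpose to (N, L, P)
    self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
    self.W_K = self.W_K.transpose(0, 1)
    self.W_K_scale = self.W_K_scale.transpose(0, 1)

    # W_UV: (L, N, V) -> (N, V, L) -> quantize
    self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(W_UV.permute(1, 2, 0))
elif self.is_aiter_triton_fp8_bmm_enabled:
    ...  # existing fp8 branch (W_K/W_V via dynamic_per_batched_tensor_quant)
else:
    ...  # existing bf16 branch (W_UV / W_UK_T)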

x = out.view(-1, self.num_heads * self.v_head_dim)
elif self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
out = out.transpose(0, 1)

# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
@@ -1578,80 +1608,6 @@ def _run_prefill_context_chunk_trtllm_ragged(
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()

def process_weights_after_loading(self, act_dtype: torch.dtype):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)

W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
W_K, dtype=current_platform.fp8_dtype()
)
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
W_V, dtype=current_platform.fp8_dtype()
)

# The kernel operates on non-padded inputs. Hence, pre-compiling
# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
pre_compilation_list = list(range(1, max_batch_size + 1))
if is_global_first_rank():
pre_compilation_list = tqdm(
pre_compilation_list,
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
total=max_batch_size,
)

for m in pre_compilation_list:
x = torch.empty(
(self.W_K.shape[0], m, self.W_K.shape[2]),
dtype=torch.bfloat16,
device=self.W_K.device,
)
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)

x = torch.empty(
(self.W_V.shape[0], m, self.W_V.shape[2]),
dtype=torch.bfloat16,
device=self.W_V.device,
)
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)

def _concat_k_nope_k_pe(
self, k_nope: torch.Tensor, k_pe: torch.Tensor
) -> torch.Tensor:
@@ -2032,7 +1988,18 @@ def forward(
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

if self.is_aiter_triton_fp8_bmm_enabled:
if self.is_aiter_triton_fp4_bmm_enabled:
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

decode_ql_nope = batched_gemm_a16wfp4(
decode_q_nope,
self.W_K,
self.W_K_scale,
transpose_bm=True,
prequant=True,
y_scale=layer._q_scale if fp8_attention else None,
)
elif self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
14 changes: 14 additions & 0 deletions vllm/model_executor/layers/quantization/quark/utils.py
@@ -6,6 +6,7 @@
from typing import Any

import regex as re
import torch


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -103,3 +104,16 @@ def _is_equal_or_regex_match(
elif target == value:
return True
return False


# Utility for quantizing weight tensors with more than 2 dims to MXFP4.
def quark_quantize_weight_to_mxfp4(w: torch.Tensor):
    assert w.dtype == torch.bfloat16, (
        "Quark dynamic quantization is supported only for bf16 weights and only to MXFP4"
    )

from aiter.ops.triton.quant import dynamic_mxfp4_quant

*dims, d = w.shape
w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d))
return w.view(*dims, d // 2), w_scales.view(*dims, d // 32)
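
A short usage sketch for the helper (illustrative shapes matching the MLA path above; needs a ROCm build with aiter). The packing ratios follow directly from the reshapes in the function: two fp4 values per output byte, one scale per 32-element group.

import torch

N, V, L = 16, 128, 512
w = torch.randn(N, V, L, dtype=torch.bfloat16, device="cuda")

w_q, w_scales = quark_quantize_weight_to_mxfp4(w)
assert w_q.shape == (N, V, L // 2)        # two fp4 values packed per byte
assert w_scales.shape == (N, V, L // 32)  # one scale per 32-element group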