diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index b443f773525a..d06f32435123 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -836,6 +836,7 @@ class rocm_aiter_ops: _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE @@ -861,6 +862,7 @@ def refresh_env_variables(cls): cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -916,6 +918,11 @@ def is_triton_unified_attn_enabled(cls) -> bool: def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + @classmethod + @if_aiter_supported + def is_fp4bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED + @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: @@ -1396,6 +1403,29 @@ def triton_rotary_embed( query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def batched_gemm_a16wfp4( + X: torch.Tensor, + W: torch.Tensor, + w_scale: torch.Tensor, + Y: torch.Tensor, + transpose_bm: bool | None = False, + prequant: bool | None = False, + y_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + return batched_gemm_a16wfp4( + X, + W, + w_scale, + y=Y, + 
transpose_bm=transpose_bm, + prequant=prequant, + y_scale=y_scale, + ) + @staticmethod def triton_fp8_bmm( X: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index d77c1e9d95e2..53fcd7d583e2 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -121,6 +121,7 @@ VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True @@ -990,6 +991,11 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index aa2740caca71..fc2a4fad9c66 100755 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1184,6 +1184,12 @@ def __init__( self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = ( + rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + def process_weights_after_loading(self, act_dtype: torch.dtype): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform @@ -1212,7 +1218,21 @@ def 
process_weights_after_loading(self, act_dtype: torch.dtype): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if self.is_aiter_triton_fp8_bmm_enabled: + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import ( + quark_quantize_weight_to_mxfp4, + ) + + self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) + # Convert from (L, N, P) to (N, L, P) + self.W_K = self.W_K.transpose(0, 1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + + self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( + W_UV.permute(1, 2, 0) + ) + elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1262,16 +1282,26 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if self.is_aiter_triton_fp8_bmm_enabled: - out = out.view(-1, self.num_heads, self.v_head_dim) + out = out.view(-1, self.num_heads, self.v_head_dim) + if self.is_aiter_triton_fp4_bmm_enabled: + out = rocm_aiter_ops.batched_gemm_a16wfp4( + x, + self.W_V, + self.W_V_scale, + out, + transpose_bm=True, + prequant=True, + y_scale=None, + ) + x = out.view(-1, self.num_heads * self.v_head_dim) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) else: # Convert from (B, N * V) to (N, B, V) - out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + out = out.transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" @@ -1578,80 +1608,6 @@ def 
_run_prefill_context_chunk_trtllm_ragged( # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def process_weights_after_loading(self, act_dtype: torch.dtype): - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights( - self.kv_b_proj, out_dtype=act_dtype - ).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}" - ) - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - - if self.is_aiter_triton_fp8_bmm_enabled: - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype() - ) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype() - ) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. 
- max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - def _concat_k_nope_k_pe( self, k_nope: torch.Tensor, k_pe: torch.Tensor ) -> torch.Tensor: @@ -2032,7 +1988,18 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if self.is_aiter_triton_fp8_bmm_enabled: + if self.is_aiter_triton_fp4_bmm_enabled: + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + decode_ql_nope = batched_gemm_a16wfp4( + decode_q_nope, + self.W_K, + self.W_K_scale, + transpose_bm=True, + prequant=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index dc82f94ebbbf..98ac1a4f355e 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -6,6 +6,7 @@ from typing import Any import regex as re +import torch def 
deep_compare(dict1: Any, dict2: Any) -> bool: @@ -103,3 +104,16 @@ def _is_equal_or_regex_match( elif target == value: return True return False + + +# utility for tensor dims > 2 cases +def quark_quantize_weight_to_mxfp4(w: torch.Tensor): +    assert w.dtype == torch.bfloat16, ( +        "Quark dynamic quantization is supported only for bf16 weights and only to MXFP4" +    ) + +    from aiter.ops.triton.quant import dynamic_mxfp4_quant + +    *dims, d = w.shape +    w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d)) +    return w.view(*dims, d // 2), w_scales.view(*dims, d // 32)