From 121f4cd705197b9d6bb77a01c0134563ee767b3b Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 13 Jan 2026 05:20:19 +0000 Subject: [PATCH 1/5] initial commit, something wrong with transpositions Signed-off-by: Aleksandr Malyshev --- vllm/_aiter_ops.py | 30 ++++++++++ vllm/envs.py | 6 ++ .../layers/quantization/quark/utils.py | 14 +++++ vllm/v1/attention/backends/mla/common.py | 59 ++++++++++++++++--- 4 files changed, 101 insertions(+), 8 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index b443f773525a..b320274c4f3f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -836,6 +836,7 @@ class rocm_aiter_ops: _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE @@ -861,6 +862,7 @@ def refresh_env_variables(cls): cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -916,6 +918,11 @@ def is_triton_unified_attn_enabled(cls) -> bool: def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + @classmethod + @if_aiter_supported + def is_fp4bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED + @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: @@ -1396,6 +1403,29 @@ def triton_rotary_embed( query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def batched_gemm_a16wfp4( + X: torch.Tensor, + W: torch.Tensor, + w_scale: torch.Tensor, + Y: torch.Tensor, + transpose_bm: bool | None = False, + prequant: bool | None = False, + y_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + return batched_gemm_a16wfp4( + X, + W, + w_scale, + Y, + transpose_bm=transpose_bm, + prequant=prequant, + y_scale=y_scale, + ) + @staticmethod def triton_fp8_bmm( X: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index a9f6123a7d06..5268f9520971 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -121,6 +121,7 @@ VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True @@ -985,6 +986,11 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index dc82f94ebbbf..c2386a89b1b4 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -6,6 +6,7 @@ from typing import Any import regex as re +import torch def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -103,3 +104,16 @@ def _is_equal_or_regex_match( elif target == value: return True return False + + +# utility for tensor dims > 2 cases +def quark_quntize_weight_to_mxfp4(w: torch.Tensor): + assert w.dtype == torch.bfloat16, ( + "Quark dynamic quantization is supported only for fp16 weights and only to MXF4" + ) + + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + *dims, d = w.shape + w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d)) + return w.view(*dims, d // 2), w_scales.view(*dims, d // 32) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2ee2740a51ba..a6bf021d609f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1183,6 +1183,12 @@ def __init__( self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = ( + rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") @@ -1283,16 +1289,26 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if self.is_aiter_triton_fp8_bmm_enabled: - out = out.view(-1, self.num_heads, self.v_head_dim) + out = out.view(-1, self.num_heads, self.v_head_dim) + if self.is_aiter_triton_fp4_bmm_enabled: + out = rocm_aiter_ops.batched_gemm_a16wfp4( + x, + self.W_V, + self.W_V_scale, + y=out, + transpose_bm=True, + prequant=True, + y_scale=None, + ) + x = out.view(-1, self.num_heads * self.v_head_dim) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) else: # Convert from (B, N * V) to (N, B, V) - out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + out = out.transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" @@ -1644,12 +1660,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) - W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if self.is_aiter_triton_fp8_bmm_enabled: + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import ( + quark_quntize_weight_to_mxfp4, + ) + + self.W_K, self.W_K_scale = quark_quntize_weight_to_mxfp4(W_UK) + self.W_V, self.W_V_scale = quark_quntize_weight_to_mxfp4(W_UV) + # Convert from (L, N, P) to (N, P, L) + self.W_K = self.W_K.permute(1, 2, 0) + # Convert from (L, N, V) to (N, L, V) + self.W_V = self.W_V.transpose(0, 1) + elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -2045,7 +2072,7 @@ def forward( scale=layer._k_scale, ) - if fp8_attention: + if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled: kv_cache = kv_cache.view(current_platform.fp8_dtype()) if has_prefill: @@ -2076,7 +2103,23 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if self.is_aiter_triton_fp8_bmm_enabled: + if self.is_aiter_triton_fp4_bmm_enabled: + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + decode_q_nope = decode_q_nope.transpose(0, 1) + decode_ql_nope = None + print(f"{decode_q_nope.shape=}") + print(f"{self.W_K.shape=}") + decode_ql_nope = batched_gemm_a16wfp4( + decode_q_nope, + self.W_K, + self.W_K_scale, + y=decode_ql_nope, + transpose_bm=True, + prequant=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, From e50c27d1b816a91e977141c600a98e8526b196cb Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 13 Jan 2026 06:02:58 +0000 Subject: [PATCH 2/5] initial commit, something wrong with transpositions Signed-off-by: Aleksandr Malyshev --- vllm/v1/attention/backends/mla/common.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a6bf021d609f..b23b7b760a35 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1291,6 +1291,10 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) out = out.view(-1, self.num_heads, self.v_head_dim) if self.is_aiter_triton_fp4_bmm_enabled: + print(f"{x.shape=}") + print(f"{self.W_V.shape=}") + print(f"{self.W_V_scale.shape=}") + out = rocm_aiter_ops.batched_gemm_a16wfp4( x, self.W_V, @@ -1671,11 +1675,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) self.W_K, self.W_K_scale = quark_quntize_weight_to_mxfp4(W_UK) + # Convert from (L, N, P) to (N, P, L) ??? + self.W_K = self.W_K.transpose(0, 1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + + print(f"{W_UV.shape=}") self.W_V, self.W_V_scale = quark_quntize_weight_to_mxfp4(W_UV) - # Convert from (L, N, P) to (N, P, L) - self.W_K = self.W_K.permute(1, 2, 0) - # Convert from (L, N, V) to (N, L, V) - self.W_V = self.W_V.transpose(0, 1) + print(f"{self.W_V.shape=}") + print(f"{self.W_V_scale.shape=}") + + # Convert from (L, N, V) to (N, L, V) ??? + self.W_V = self.W_V.permute(1, 2, 0) elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 @@ -2106,10 +2116,12 @@ def forward( if self.is_aiter_triton_fp4_bmm_enabled: from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 - decode_q_nope = decode_q_nope.transpose(0, 1) decode_ql_nope = None + print(f"{decode_q_nope.shape=}") print(f"{self.W_K.shape=}") + print(f"{self.W_K_scale.shape=}") + decode_ql_nope = batched_gemm_a16wfp4( decode_q_nope, self.W_K, From a96833e85a6fde345cef9e5e0ba21996ac58e5b1 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 13 Jan 2026 06:38:02 +0000 Subject: [PATCH 3/5] refactored and fixed code, output is reasonable Signed-off-by: Aleksandr Malyshev --- vllm/_aiter_ops.py | 2 +- vllm/v1/attention/backends/mla/common.py | 23 +++++------------------ 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index b320274c4f3f..d06f32435123 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1420,7 +1420,7 @@ def batched_gemm_a16wfp4( X, W, w_scale, - Y, + y=Y, transpose_bm=transpose_bm, prequant=prequant, y_scale=y_scale, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b23b7b760a35..7557174a45bd 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1291,15 +1291,11 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) out = out.view(-1, self.num_heads, self.v_head_dim) if self.is_aiter_triton_fp4_bmm_enabled: - print(f"{x.shape=}") - print(f"{self.W_V.shape=}") - print(f"{self.W_V_scale.shape=}") - out = rocm_aiter_ops.batched_gemm_a16wfp4( x, self.W_V, self.W_V_scale, - y=out, + out, transpose_bm=True, prequant=True, y_scale=None, @@ -1675,17 +1671,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) self.W_K, self.W_K_scale = quark_quntize_weight_to_mxfp4(W_UK) - # Convert from (L, N, P) to (N, P, L) ??? + # Convert from (L, N, P) to (N, L, P) self.W_K = self.W_K.transpose(0, 1) self.W_K_scale = self.W_K_scale.transpose(0, 1) - print(f"{W_UV.shape=}") - self.W_V, self.W_V_scale = quark_quntize_weight_to_mxfp4(W_UV) - print(f"{self.W_V.shape=}") - print(f"{self.W_V_scale.shape=}") - - # Convert from (L, N, V) to (N, L, V) ??? - self.W_V = self.W_V.permute(1, 2, 0) + self.W_V, self.W_V_scale = quark_quntize_weight_to_mxfp4( + W_UV.permute(1, 2, 0) + ) elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 @@ -2117,11 +2109,6 @@ def forward( from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 decode_ql_nope = None - - print(f"{decode_q_nope.shape=}") - print(f"{self.W_K.shape=}") - print(f"{self.W_K_scale.shape=}") - decode_ql_nope = batched_gemm_a16wfp4( decode_q_nope, self.W_K, From 02e7e2a2fcb353297e1eae6e060dc051a00d32e2 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 13 Jan 2026 20:25:06 +0000 Subject: [PATCH 4/5] minor review corrections Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/layers/quantization/quark/utils.py | 2 +- vllm/v1/attention/backends/mla/common.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index c2386a89b1b4..98ac1a4f355e 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -107,7 +107,7 @@ def _is_equal_or_regex_match( # utility for tensor dims > 2 cases -def quark_quntize_weight_to_mxfp4(w: torch.Tensor): +def quark_quantize_weight_to_mxfp4(w: torch.Tensor): assert w.dtype == torch.bfloat16, ( "Quark dynamic quantization is supported only for fp16 weights and only to MXF4" ) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7557174a45bd..54dafd5b9ef5 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1667,15 +1667,15 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported if self.is_aiter_triton_fp4_bmm_enabled: from vllm.model_executor.layers.quantization.quark.utils import ( - quark_quntize_weight_to_mxfp4, + quark_quantize_weight_to_mxfp4, ) - self.W_K, self.W_K_scale = quark_quntize_weight_to_mxfp4(W_UK) + self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) # Convert from (L, N, P) to (N, L, P) self.W_K = self.W_K.transpose(0, 1) self.W_K_scale = self.W_K_scale.transpose(0, 1) - self.W_V, self.W_V_scale = quark_quntize_weight_to_mxfp4( + self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( W_UV.permute(1, 2, 0) ) elif self.is_aiter_triton_fp8_bmm_enabled: @@ -2113,7 +2113,6 @@ def forward( decode_q_nope, self.W_K, self.W_K_scale, - y=decode_ql_nope, transpose_bm=True, prequant=True, y_scale=layer._q_scale if fp8_attention else None, From ae10b68c1ceeb588f32c401338b4d7934e53d862 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Wed, 14 Jan 2026 23:38:30 +0000 Subject: [PATCH 5/5] cursor comments fixes Signed-off-by: Aleksandr Malyshev --- .../layers/attention/mla_attention.py | 106 +++--------------- 1 file changed, 16 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 21ff6b07ba45..fc2a4fad9c66 100755 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1218,7 +1218,21 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if self.is_aiter_triton_fp8_bmm_enabled: + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import ( + quark_quantize_weight_to_mxfp4, + ) + + self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) + # Convert from (L, N, P) to (N, L, P) + self.W_K = self.W_K.transpose(0, 1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + + self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( + W_UV.permute(1, 2, 0) + ) + elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1594,93 +1608,6 @@ def _run_prefill_context_chunk_trtllm_ragged( # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def process_weights_after_loading(self, act_dtype: torch.dtype): - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights( - self.kv_b_proj, out_dtype=act_dtype - ).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}" - ) - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - - # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported - if self.is_aiter_triton_fp4_bmm_enabled: - from vllm.model_executor.layers.quantization.quark.utils import ( - quark_quantize_weight_to_mxfp4, - ) - - self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) - # Convert from (L, N, P) to (N, L, P) - self.W_K = self.W_K.transpose(0, 1) - self.W_K_scale = self.W_K_scale.transpose(0, 1) - - self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( - W_UV.permute(1, 2, 0) - ) - elif self.is_aiter_triton_fp8_bmm_enabled: - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype() - ) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype() - ) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. - max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - def _concat_k_nope_k_pe( self, k_nope: torch.Tensor, k_pe: torch.Tensor ) -> torch.Tensor: @@ -2030,7 +1957,7 @@ def forward( scale=layer._k_scale, ) - if fp8_attention and not self.is_aiter_triton_fp4_bmm_enabled: + if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) if has_prefill: @@ -2064,7 +1991,6 @@ def forward( if self.is_aiter_triton_fp4_bmm_enabled: from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 - decode_ql_nope = None decode_ql_nope = batched_gemm_a16wfp4( decode_q_nope, self.W_K,