From 089ebd4728399ff664ecf0065e12a8cbca7f9b42 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 15 May 2026 21:56:38 -0500 Subject: [PATCH 1/9] fix deepseek v4 high concurrency issue Signed-off-by: tjtanaa --- vllm/model_executor/layers/mhc.py | 88 ++++++++++--------- .../layers/sparse_attn_indexer.py | 27 ++---- vllm/model_executor/models/deepseek_v4.py | 3 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 86 ++++++++---------- 4 files changed, 94 insertions(+), 110 deletions(-) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 36325d35202c..a9f84ffaf041 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -61,31 +61,35 @@ def forward_hip( sinkhorn_repeat: int, n_splits: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_size = residual.shape[-1] - if hidden_size % 256 == 0: - return torch.ops.vllm.mhc_pre_aiter( - residual, - fn, - hc_scale, - hc_base, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - ) - else: - return mhc_kernels.mhc_pre_torch( - residual, - fn, - hc_scale, - hc_base, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - ) + # TODO: Reenable aiter after we are at the aiter + # version that has this bugfix + # https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649 + # It has accuracy problem at large number of tokens. 
+ # hidden_size = residual.shape[-1] + # if hidden_size % 256 == 0: + # return torch.ops.vllm.mhc_pre_aiter( + # residual, + # fn, + # hc_scale, + # hc_base, + # rms_eps, + # hc_pre_eps, + # hc_sinkhorn_eps, + # hc_post_mult_value, + # sinkhorn_repeat, + # ) + # else: + return mhc_kernels.mhc_pre_torch( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + ) def forward_native(self, *args, **kwargs): raise NotImplementedError("Native implementation of mhc_pre is not available") @@ -124,21 +128,25 @@ def forward_hip( post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: - hidden_size = residual.shape[-1] - if hidden_size % 256 == 0: - return torch.ops.vllm.mhc_post_aiter( - x, - residual, - post_layer_mix, - comb_res_mix, - ) - else: - return mhc_kernels.mhc_post_torch( - x, - residual, - post_layer_mix, - comb_res_mix, - ) + # TODO: Reenable aiter after we are at the aiter + # version that has this bugfix + # https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649 + # It has accuracy problem at large number of tokens. 
+ # hidden_size = residual.shape[-1] + # if hidden_size % 256 == 0: + # return torch.ops.vllm.mhc_post_aiter( + # x, + # residual, + # post_layer_mix, + # comb_res_mix, + # ) + # else: + return mhc_kernels.mhc_post_torch( + x, + residual, + post_layer_mix, + comb_res_mix, + ) def forward_native(self, *args, **kwargs): raise NotImplementedError("Native implementation of mhc_post is not available") diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43f..1859f1c6c360 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -503,27 +503,6 @@ def forward_hip( assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) - if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled(): - from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( - rocm_aiter_sparse_attn_indexer_native, - ) - - return rocm_aiter_sparse_attn_indexer_native( - hidden_states, - _encode_layer_name(self.k_cache.prefix), - self.k_cache.kv_cache, - q_quant, - k, - weights, - self.quant_block_size, - self.scale_fmt, - self.topk_tokens, - self.head_dim, - self.max_model_len, - self.max_total_seq_len, - self.topk_indices_buffer, - skip_k_cache_insert=self.skip_k_cache_insert, - ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -539,5 +518,9 @@ def forward_hip( self.max_model_len, self.max_total_seq_len, self.topk_indices_buffer, + skip_k_cache_insert=self.skip_k_cache_insert, ) - raise RuntimeError("Sparse attention indexer ROCm path could not be selected.") + raise RuntimeError( + "Sparse attention indexer ROCm path is only supported on AITER. 
" + "Please enable aiter with VLLM_ROCM_USE_AITER=1" + ) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 3f3a2fb17026..65606796c9e9 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1277,7 +1277,8 @@ def _forward_rocm( x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) - x = self.ffn_norm(x) + # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes + # the pre-norm activation directly. x = self.ffn(x, input_ids) x = self.hc_post(x, residual, post, comb) return x, None, None, None diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index e4d8eed6f771..d6d8ba69db85 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -541,7 +541,11 @@ def rocm_fp8_mqa_logits( return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor: +def _topk_indices_torch( + logits: torch.Tensor, + topk_tokens: int, + row_starts: torch.Tensor | None = None, +) -> torch.Tensor: k = min(topk_tokens, logits.shape[-1]) values, indices = torch.topk(logits, k=k, dim=-1) indices = indices.to(torch.int32) @@ -550,6 +554,12 @@ def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor: torch.full_like(indices, -1, dtype=torch.int32), indices, ) + if row_starts is not None: + # Match the CUDA top_k_per_row_prefill contract: indices are local to + # each row's valid [row_start, row_end) range, not columns in the + # concatenated chunk logits matrix. 
+ starts = row_starts.to(dtype=torch.int32).view(-1, 1) + indices = torch.where(indices < 0, indices, indices - starts) if k == topk_tokens: return indices padded = torch.full( @@ -576,21 +586,12 @@ def rocm_aiter_sparse_attn_indexer_fake( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, + skip_k_cache_insert: bool = False, ) -> torch.Tensor: - # profile run - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. - device = hidden_states.device if k is None else k.device - _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=device, dtype=torch.uint8 - ) - fp8_dtype = current_platform.fp8_dtype() - _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() - _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer -def rocm_aiter_sparse_attn_indexer_native( +def rocm_aiter_sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, @@ -629,6 +630,7 @@ def rocm_aiter_sparse_attn_indexer_native( max_model_len, total_seq_lens, topk_indices_buffer, + skip_k_cache_insert, ) layer_attn_metadata = attn_metadata[k_cache_prefix] assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata) @@ -709,7 +711,19 @@ def rocm_aiter_sparse_attn_indexer_native( topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)) + + num_rows = logits.shape[0] + + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -747,7 +761,18 @@ def rocm_aiter_sparse_attn_indexer_native( ) topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - topk_indices.copy_(_topk_indices_torch(logits, 
topk_tokens)[:num_decode_tokens]) + num_rows = logits.shape[0] + + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -763,39 +788,6 @@ def rocm_aiter_sparse_attn_indexer_native( return topk_indices_buffer -def rocm_aiter_sparse_attn_indexer( - hidden_states: torch.Tensor, - k_cache_prefix: LayerNameType, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - return rocm_aiter_sparse_attn_indexer_native( - hidden_states, - k_cache_prefix, - kv_cache, - q_fp8, - k, - weights, - quant_block_size, - scale_fmt, - topk_tokens, - head_dim, - max_model_len, - total_seq_lens, - topk_indices_buffer, - skip_k_cache_insert=False, - ) - - def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor: if scale.dtype == torch.float8_e8m0fnu: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( From 08f73d34bd5b5dadcc4fa48a144849d9f0c203f2 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 15 May 2026 22:41:14 -0500 Subject: [PATCH 2/9] num_padded_tokens Signed-off-by: tjtanaa --- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index d69246192a82..7a8075168fe8 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -760,7 +760,7 @@ def rocm_aiter_sparse_attn_indexer( max_model_len=max_model_len, ) - topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + topk_indices = topk_indices_buffer[:num_padded_tokens, 
:topk_tokens] num_rows = logits.shape[0] torch.ops._C.top_k_per_row_decode( From bd0004f20d377328b4dc484ea4926c29fefaf0ae Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sat, 16 May 2026 15:36:09 +0000 Subject: [PATCH 3/9] fix precommit Signed-off-by: tjtanaa --- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 30f9e0524401..8bda5bb7a86b 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -591,6 +591,7 @@ def rocm_aiter_sparse_attn_indexer_fake( ) -> torch.Tensor: return topk_indices_buffer + @eager_break_during_capture def rocm_aiter_sparse_attn_indexer( hidden_states: torch.Tensor, From c0219ecc5cb4af73ba359671bcbcd9a815870a95 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sat, 16 May 2026 13:17:51 -0500 Subject: [PATCH 4/9] add v1 support of fp4indexer Signed-off-by: tjtanaa --- .../test_rocm_fp8_fp4_paged_mqa_logits.py | 549 ++++++++++++++ vllm/config/attention.py | 3 +- .../layers/deepseek_v4_attention.py | 4 +- .../layers/sparse_attn_indexer.py | 45 +- vllm/utils/deep_gemm.py | 87 ++- vllm/v1/attention/backends/mla/indexer.py | 25 +- .../ops/deepseek_v4_ops/fused_indexer_q.py | 66 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 452 ++++++++++-- .../ops/rocm_fp8_fp4_paged_mqa_logits.py | 691 ++++++++++++++++++ 9 files changed, 1805 insertions(+), 117 deletions(-) create mode 100644 tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py create mode 100644 vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py diff --git a/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py b/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py new file mode 100644 index 000000000000..b0e817396368 --- /dev/null +++ b/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py @@ -0,0 +1,549 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import os +import random + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import ( + calc_diff, + fp8_fp4_paged_mqa_logits, + get_num_sms, + get_paged_mqa_logits_metadata, +) +from vllm.utils.math_utils import cdiv + +pytestmark = pytest.mark.skipif( + not current_platform.is_rocm(), reason="ROCm Triton kernel only" +) + +MXFP4_BLOCK_SIZE = 32 +FP4_VALUES = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], + dtype=torch.float32, +) + + +@dataclasses.dataclass(frozen=True) +class PagedMQACase: + is_varlen: bool + is_fp4: bool + logits_dtype: torch.dtype + block_kv: int + batch_size: int + next_n: int + max_tokens_per_batch: int + num_heads: int + head_dim: int + avg_kv: int + + def id(self) -> str: + dtype = "bf16" if self.logits_dtype is torch.bfloat16 else "fp32" + quant = "fp4" if self.is_fp4 else "fp8" + varlen = f"varlen{self.max_tokens_per_batch}" if self.is_varlen else "dense" + return ( + f"{varlen}-{quant}-{dtype}-blk{self.block_kv}-" + f"b{self.batch_size}-n{self.next_n}-kv{self.avg_kv}" + ) + + +def _deepgemm_paged_mqa_cases() -> list[PagedMQACase]: + cases: list[PagedMQACase] = [] + for is_varlen in (True, False): + for is_fp4 in (True, False): + for logits_dtype in (torch.float32, torch.bfloat16): + for block_kv in (32, 64): + for batch_size in (256,): + next_ns = (1,) if is_varlen else (1, 2, 4, 5, 6) + for next_n in next_ns: + max_tpbs = (1, 4, 10) if is_varlen else (1,) + for max_tokens_per_batch in max_tpbs: + for num_heads, head_dim in [(32, 128), (64, 128)]: + for avg_kv in (8192, 32768): + cases.append( + PagedMQACase( + is_varlen=is_varlen, + is_fp4=is_fp4, + logits_dtype=logits_dtype, + block_kv=block_kv, + batch_size=batch_size, + next_n=next_n, + max_tokens_per_batch=( + max_tokens_per_batch + ), + num_heads=num_heads, + head_dim=head_dim, + avg_kv=avg_kv, + ) + ) + return cases + + 
+DEEPGEMM_PAGED_MQA_CASES = _deepgemm_paged_mqa_cases() +FULL_DEEPGEMM_SHAPES = os.getenv( + "VLLM_ROCM_PAGED_MQA_FULL_DEEPGEMM_SHAPES", "0" +) == "1" + + +def _scaled_case_dims(case: PagedMQACase) -> tuple[int, int, int]: + if FULL_DEEPGEMM_SHAPES: + return case.batch_size, case.avg_kv, 111 * 1024 + # Preserve the DeepGEMM case matrix while keeping unit tests quick enough + # for per-kernel validation. + batch_size = 4 if not case.is_varlen else 3 + avg_kv = 192 if case.avg_kv == 8192 else 448 + max_model_len = int(1.4 * avg_kv) + case.block_kv * 2 + return batch_size, avg_kv, max_model_len + + +def _quantize_to_mxfp4(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + head_dim = orig_shape[-1] + assert head_dim % MXFP4_BLOCK_SIZE == 0 + n_blocks = head_dim // MXFP4_BLOCK_SIZE + + x_f32 = x.float().reshape(-1, n_blocks, MXFP4_BLOCK_SIZE) + amax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=6 * (2**-126)) + log2_ratio = (amax / 6.0).log2().ceil().clamp(-127.0, 127.0) + scale = log2_ratio.exp2() + scales = (log2_ratio + 127.0).to(torch.uint8) + + x_scaled = (x_f32 / scale).clamp(-6.0, 6.0) + abs_x = x_scaled.abs() + code = torch.zeros_like(abs_x, dtype=torch.int32) + code = torch.where(abs_x > 0.25, 1, code) + code = torch.where(abs_x >= 0.75, 2, code) + code = torch.where(abs_x > 1.25, 3, code) + code = torch.where(abs_x >= 1.75, 4, code) + code = torch.where(abs_x > 2.5, 5, code) + code = torch.where(abs_x >= 3.5, 6, code) + code = torch.where(abs_x > 5.0, 7, code) + sign = ((x_scaled.view(torch.int32) >> 31) & 1).to(torch.uint8) + nibble = code.to(torch.uint8) | (sign << 3) + + nibble = nibble.reshape(-1, head_dim) + packed = (nibble[:, 0::2] | (nibble[:, 1::2] << 4)).contiguous() + packed = packed.reshape(*orig_shape[:-1], head_dim // 2) + scales = scales.reshape(*orig_shape[:-1], n_blocks) + return packed, scales + + +def _dequantize_mxfp4(packed: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: + table = 
FP4_VALUES.to(device=packed.device) + bytes_i = packed.to(torch.int16) + lo = bytes_i & 0xF + hi = (bytes_i >> 4) & 0xF + nibbles = torch.stack((lo, hi), dim=-1).flatten(-2).to(torch.long) + mag = nibbles & 0x7 + sign = (nibbles & 0x8) != 0 + values = table[mag] + values = torch.where(sign, -values, values) + scale = (scales.to(torch.float32) - 127.0).exp2() + scale = torch.repeat_interleave(scale, MXFP4_BLOCK_SIZE, dim=-1) + return values * scale + + +def _kv_cache_cast_to_fp8( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / torch.finfo(current_platform.fp8_dtype()).max + x_scaled = (x * (1.0 / sf)).to(current_platform.fp8_dtype()) + x_cast_back = x_scaled.float() * sf + + x_fp8 = torch.empty( + (num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + torch.uint8 + ) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), ( + x_cast_back.to(x.dtype) + ) + + +def _kv_cache_cast_to_fp4( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 and head_dim == 128 + packed, scales = _quantize_to_mxfp4(x) + x_cast_back = _dequantize_mxfp4(packed, scales).view_as(x).to(x.dtype) + scales_i32 = scales.view(torch.int32).squeeze(-1) + + x_fp4 = torch.empty( + (num_blocks, block_size * (head_dim // 2 + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp4[:, : block_size * head_dim // 2] = packed.view( + num_blocks, block_size * head_dim // 2 + ) + x_fp4[:, block_size * head_dim // 2 :] = scales_i32.view(torch.uint8).view( + num_blocks, block_size * 4 + ) + return x_fp4.view(num_blocks, 
block_size, num_heads, head_dim // 2 + 4), ( + x_cast_back + ) + + +def _ref_paged_mqa_logits( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens_nextn: torch.Tensor, + block_table: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + batch_size, next_n, _, _ = q.size() + _, block_size, _, _ = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens_cpu = context_lens_nextn.cpu() + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens_cpu[batch_idx, next_idx].item()) + qx = q[batch_idx, next_idx].float() + weight = weights[row].float() + for block_rk in range(cdiv(context_len, block_size)): + block_idx = int(block_table[batch_idx, block_rk].item()) + kx = kv_cache[block_idx, :, 0].float() + offsets = torch.arange( + block_rk * block_size, + (block_rk + 1) * block_size, + device=q.device, + ) + valid = offsets < context_len + score = torch.einsum("hd,kd->hk", qx, kx) + score = torch.relu(score) * weight[:, None] + score = score.sum(dim=0) + logits[row, offsets] = torch.where(valid, score, float("-inf")) + return logits + + +def _ref_fp4_mqa_logits( + q: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + logits = torch.full( + (q.shape[0], k.shape[0]), + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + starts = cu_seqlen_ks.cpu().tolist() + ends = cu_seqlen_ke.cpu().tolist() + for row, (start, end) in enumerate(zip(starts, ends)): + if end <= start: + continue + score = torch.einsum("hd,nd->hn", q[row].float(), k[start:end].float()) + score = torch.relu(score) * weights[row].float()[:, None] + logits[row, start:end] = score.sum(dim=0) + return logits + + +@torch.inference_mode() +def test_rocm_fp4_mqa_logits_matches_reference() -> None: + from 
vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import ( + rocm_fp4_mqa_logits, + ) + + device = torch.device("cuda") + torch.manual_seed(123) + seq_len = 7 + seq_len_kv = 19 + num_heads = 8 + head_dim = 128 + q = torch.randn( + (seq_len, num_heads, head_dim), + device=device, + dtype=torch.bfloat16, + ) * 0.125 + k = torch.randn((seq_len_kv, head_dim), device=device, dtype=torch.bfloat16) + k = k * 0.125 + weights = torch.randn((seq_len, num_heads), device=device, dtype=torch.float32) + cu_seqlen_ks = torch.tensor( + [0, 0, 2, 4, 7, 9, 12], device=device, dtype=torch.int32 + ) + cu_seqlen_ke = torch.tensor( + [3, 5, 8, 11, 13, 17, 19], device=device, dtype=torch.int32 + ) + + q_packed, q_scales = _quantize_to_mxfp4(q) + k_packed, k_scales = _quantize_to_mxfp4(k) + q_dequant = _dequantize_mxfp4(q_packed, q_scales).view_as(q).bfloat16() + k_dequant = _dequantize_mxfp4(k_packed, k_scales).view_as(k).bfloat16() + + actual = rocm_fp4_mqa_logits( + (q_packed.view(torch.int8), q_scales.view(torch.int32).squeeze(-1)), + (k_packed.view(torch.int8), k_scales.view(torch.int32).squeeze(-1)), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + expected = _ref_fp4_mqa_logits( + q_dequant, + k_dequant, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + + finite = torch.isfinite(expected) + diff = calc_diff(actual.masked_fill(~finite, 0), expected.masked_fill(~finite, 0)) + assert diff < 2e-2, f"non-paged fp4 MQA diff={float(diff):.6f}" + torch.testing.assert_close(actual[~finite], expected[~finite]) + + +@torch.inference_mode() +def test_cp_gather_indexer_mxfp4_cache_triton_matches_compressor_layout() -> None: + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + cp_gather_indexer_mxfp4_cache_triton, + ) + + device = torch.device("cuda") + num_blocks = 3 + block_size = 4 + head_dim = 128 + head_bytes = head_dim // 2 + scale_bytes = head_dim // MXFP4_BLOCK_SIZE + allocated_width = head_dim + 4 + k_cache = torch.zeros( + (num_blocks, block_size, 
allocated_width), + device=device, + dtype=torch.uint8, + ) + cache_flat = k_cache.flatten() + + value_rows = {} + scale_rows = {} + for block in range(num_blocks): + block_base = block * k_cache.stride(0) + for pos in range(block_size): + value = ( + torch.arange(head_bytes, device=device, dtype=torch.uint8) + + block * 17 + + pos * 3 + ) + scale = torch.tensor( + [127 + block, 128 + pos, 129 + block + pos, 130 - pos], + device=device, + dtype=torch.uint8, + ) + value_offset = block_base + pos * head_bytes + scale_offset = block_base + block_size * head_bytes + pos * scale_bytes + cache_flat[value_offset : value_offset + head_bytes] = value + cache_flat[scale_offset : scale_offset + scale_bytes] = scale + value_rows[(block, pos)] = value + scale_rows[(block, pos)] = scale + + block_table = torch.tensor([[2, 0], [1, 0]], device=device, dtype=torch.int32) + cu_seqlen = torch.tensor([0, 5, 8], device=device, dtype=torch.int32) + token_to_seq = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1], device=device) + token_to_seq = token_to_seq.to(torch.int32) + k_fp4 = torch.empty((8, head_bytes), device=device, dtype=torch.uint8) + k_scale = torch.empty((8, scale_bytes), device=device, dtype=torch.uint8) + + cp_gather_indexer_mxfp4_cache_triton( + k_cache, + k_fp4, + k_scale, + block_table, + cu_seqlen, + token_to_seq, + ) + + expected_values = [] + expected_scales = [] + for token in range(8): + req = int(token_to_seq[token].item()) + local = token - int(cu_seqlen[req].item()) + block = int(block_table[req, local // block_size].item()) + pos = local % block_size + expected_values.append(value_rows[(block, pos)]) + expected_scales.append(scale_rows[(block, pos)]) + torch.testing.assert_close(k_fp4, torch.stack(expected_values)) + torch.testing.assert_close(k_scale, torch.stack(expected_scales)) + + +@pytest.mark.parametrize( + "case", + DEEPGEMM_PAGED_MQA_CASES, + ids=[case.id() for case in DEEPGEMM_PAGED_MQA_CASES], +) +@torch.inference_mode() +def 
test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> None: + device = torch.device("cuda") + seed = DEEPGEMM_PAGED_MQA_CASES.index(case) + torch.manual_seed(seed) + random.seed(seed) + + raw_batch_size, avg_kv, max_model_len = _scaled_case_dims(case) + raw_next_n = case.next_n + if case.is_varlen: + tokens_per_seq = torch.randint( + 1, + case.max_tokens_per_batch + 1, + (raw_batch_size,), + device=device, + dtype=torch.int32, + ) + indices = torch.arange(raw_batch_size, device=device, dtype=torch.int32) + indices = indices.repeat_interleave(tokens_per_seq) + batch_size = int(tokens_per_seq.sum().item()) + next_n = 1 + else: + tokens_per_seq = None + indices = None + batch_size = raw_batch_size + next_n = raw_next_n + + q = torch.randn( + (batch_size, next_n, case.num_heads, case.head_dim), + device=device, + dtype=torch.bfloat16, + ) + q = q * 0.125 + weights = torch.randn( + (batch_size * next_n, case.num_heads), + device=device, + dtype=torch.float32, + ) + + low = max(1, int(0.7 * avg_kv)) + high = max(low + 1, int(1.3 * avg_kv)) + context_lens = torch.randint( + low, + high, + (raw_batch_size,), + device=device, + dtype=torch.int32, + ) + context_lens.clamp_(max=max_model_len) + if case.is_varlen: + assert tokens_per_seq is not None + max_ctx_len_per_seq = (context_lens + tokens_per_seq - 1).clamp( + max=max_model_len + ) + else: + max_ctx_len_per_seq = context_lens + + num_blocks_per_query = torch.ceil( + max_ctx_len_per_seq.float() / case.block_kv + ).to(torch.int32) + total_used_blocks = int(num_blocks_per_query.sum().item()) + num_total_blocks = total_used_blocks + 8 + kv_cache = torch.randn( + (num_total_blocks, case.block_kv, 1, case.head_dim), + device=device, + dtype=torch.bfloat16, + ) + kv_cache = kv_cache * 0.125 + + block_table = torch.empty( + (raw_batch_size, int(num_blocks_per_query.max().item())), + device=device, + dtype=torch.int32, + ) + block_idx_pool = torch.randperm(num_total_blocks, device=device, dtype=torch.int32) + 
offset = 0 + for i, num_blocks in enumerate(num_blocks_per_query.tolist()): + block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks] + offset += num_blocks + + if case.is_varlen: + assert tokens_per_seq is not None + context_lens = context_lens.repeat_interleave(tokens_per_seq) + offsets_within_seq = torch.cat( + [ + torch.arange(int(n.item()), device=device, dtype=torch.int32) + for n in tokens_per_seq + ] + ) + context_lens = (context_lens + offsets_within_seq).clamp(max=max_model_len) + block_table = block_table.repeat_interleave(tokens_per_seq, dim=0) + + if case.is_varlen: + context_lens_nextn = context_lens.view(-1, 1) + else: + rand = torch.rand(batch_size, next_n, device=device) + context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * rand).int() + context_lens_nextn[:, -1] = context_lens + context_lens_nextn.clamp_(min=1, max=max_model_len) + context_lens_nextn = context_lens_nextn.contiguous().to(torch.int32) + + if case.is_fp4: + q_packed, q_scales_u8 = _quantize_to_mxfp4(q) + q_in = ( + q_packed.view(batch_size, next_n, case.num_heads, case.head_dim // 2).view( + torch.int8 + ), + q_scales_u8.view(torch.int32).squeeze(-1), + ) + q_simulated = _dequantize_mxfp4(q_packed, q_scales_u8).view_as(q).to( + torch.bfloat16 + ) + kv_in, kv_simulated = _kv_cache_cast_to_fp4(kv_cache) + else: + q_in = (q.to(current_platform.fp8_dtype()), None) + q_simulated = q_in[0].to(torch.bfloat16) + kv_in, kv_simulated = _kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens_nextn, + case.block_kv, + get_num_sms(), + indices=indices, + ) + assert schedule_metadata.shape == (get_num_sms() + 1, 2) + assert schedule_metadata.dtype == torch.int32 + + logits = fp8_fp4_paged_mqa_logits( + q_in, + kv_in, + weights, + context_lens_nextn, + block_table, + schedule_metadata, + max_model_len, + clean_logits=False, + logits_dtype=case.logits_dtype, + indices=indices, + ) + assert logits.dtype == case.logits_dtype + + 
ref_logits = _ref_paged_mqa_logits( + q_simulated, + kv_simulated, + weights, + context_lens_nextn, + block_table, + max_model_len, + ) + positions = torch.arange(max_model_len, device=device).unsqueeze(0) + invalid = positions >= context_lens_nextn.view(-1, 1) + actual = logits.float().masked_fill(invalid, 0) + expected = ref_logits.masked_fill(invalid, 0) + + diff = calc_diff(actual, expected) + tolerance = 2e-2 if case.is_fp4 or case.logits_dtype is torch.bfloat16 else 1e-3 + assert diff < tolerance, f"{case.id()} diff={float(diff):.6f}" diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 52ce9f102a6c..c446f92c56f0 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -48,7 +48,8 @@ class AttentionConfig: """If set, quantize query for attention in prefill.""" use_fp4_indexer_cache: bool = False - """If set, use fp4 indexer cache for dsv32 family model (not support yet)""" + """If set, use fp4 indexer cache for dsv32 family models. + Supported on CUDA SM100 datacenter GPUs and ROCm gfx95x GPUs.""" use_non_causal: bool = False """Whether to use non-causal (bidirectional) attention.""" diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 996534b05f13..32d7907cc7fe 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -1130,8 +1130,8 @@ def __init__( assert cache_config is not None, "Deepseek V4 indexer requires cache_config" # NOTE(yifan): FP8 indxer cache use the same layout as V3.2: # head_dim bytes = 128 fp8 + 4 fp32 scale = 132. - # For FP4 indexer cache, we still allocate the same amount of memory as FP8, - # but only use the first half of the memory. + # For FP4 indexer cache, keep the same allocation width as CUDA and + # only use the MXFP4 value/scale prefix. 
k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4 self.k_cache = DeepseekV4IndexerCache( head_dim=k_cache_head_dim, diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 381e877fcb99..dba8db9d9bff 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -501,10 +501,41 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" - assert isinstance(q_quant, torch.Tensor), ( - "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" - ) + if self.use_fp4_cache: + assert isinstance(q_quant, tuple), ( + "AMD MXFP4 sparse_attn_indexer expects (q_values, q_scales)" + ) + else: + assert isinstance(q_quant, torch.Tensor), ( + "AMD FP8 sparse_attn_indexer expects a single q_quant tensor" + ) + + if ( + self.use_fp4_cache + or self.skip_k_cache_insert + or not rocm_aiter_ops.is_enabled() + ): + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_aiter_sparse_attn_indexer_native, + ) + + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + skip_k_cache_insert=self.skip_k_cache_insert, + use_fp4_cache=self.use_fp4_cache, + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -520,9 +551,5 @@ def forward_hip( self.max_model_len, self.max_total_seq_len, self.topk_indices_buffer, - skip_k_cache_insert=self.skip_k_cache_insert, ) - raise RuntimeError( - "Sparse attention indexer ROCm path is only supported on AITER. 
" - "Please enable aiter with VLLM_ROCM_USE_AITER=1" - ) + raise RuntimeError("Sparse attention indexer ROCm path could not be selected.") diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..ffb03ae2303e 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -244,6 +244,8 @@ def _lazy_init() -> None: def get_num_sms() -> int: + if current_platform.is_rocm(): + return int(current_platform.num_compute_units()) _lazy_init() dg = _import_deep_gemm() if dg is None: @@ -370,6 +372,20 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. """ + if current_platform.is_rocm() and q[1] is not None: + from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import ( + rocm_fp4_mqa_logits, + ) + + return rocm_fp4_mqa_logits( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=clean_logits, + ) + _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() @@ -384,24 +400,45 @@ def fp8_fp4_mqa_logits( def get_paged_mqa_logits_metadata( - context_lens: torch.Tensor, block_size: int, num_sms: int + context_lens: torch.Tensor, + block_size: int, + num_sms: int, + indices: torch.Tensor | None = None, ) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. Args: - context_lens: Tensor of shape [B], dtype int32; effective context length - per batch element. + context_lens: Tensor of shape [B] or [B, next_n], dtype int32; + effective context length per batch element or per decoded token. block_size: KV-cache block size in tokens (e.g., 64). - num_sms: Number of SMs available. 132 for Hopper + num_sms: Number of SMs/CUs available. + indices: Optional varlen token-to-sequence indices for DeepGEMM SM100 + style scheduling. ROCm accepts this for API compatibility. Returns: Backend-specific tensor consumed by `fp8_fp4_paged_mqa_logits` to schedule work across SMs. 
""" + if current_platform.is_rocm(): + from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import ( + rocm_get_paged_mqa_logits_metadata, + ) + + return rocm_get_paged_mqa_logits_metadata( + context_lens, + block_size, + num_sms, + indices=indices, + ) + _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) + if indices is None: + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) + return _get_paged_mqa_logits_metadata_impl( + context_lens, block_size, num_sms, indices=indices + ) def fp8_fp4_paged_mqa_logits( @@ -413,6 +450,8 @@ def fp8_fp4_paged_mqa_logits( schedule_metadata: torch.Tensor, max_model_len: int, clean_logits: bool, + logits_dtype: torch.dtype = torch.float32, + indices: torch.Tensor | None = None, ) -> torch.Tensor: """Compute MQA logits using a paged KV-cache. @@ -426,25 +465,51 @@ def fp8_fp4_paged_mqa_logits( q_values is packed uint8 and q_scale is the companion block-scale tensor. kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + D+4], dtype `torch.uint8`, with one float32 scale per token. + FP4 layout is [num_blocks, block_size, 1, D/2+4], with packed + MXFP4 values and one packed int32 UE8M0 scale word per token. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. - context_lens: Tensor of shape [B], dtype int32; effective context length - for each batch element. + context_lens: Tensor of shape [B] or [B, next_n], dtype int32; + effective context length for each batch element/token. block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; used to distribute work across SMs. max_model_len: Maximum sequence length used to size the logits output. 
clean_logits: Whether to clean the unfilled logits into `-inf`. + logits_dtype: Output dtype, matching DeepGEMM's float32/bfloat16 API. + indices: Optional varlen token-to-sequence indices. Returns: Logits tensor of shape [B * next_n, max_model_len], dtype - `torch.float32`. + `logits_dtype`. """ + if current_platform.is_rocm(): + from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import ( + rocm_fp8_fp4_paged_mqa_logits, + ) + + return rocm_fp8_fp4_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=clean_logits, + logits_dtype=logits_dtype, + indices=indices, + ) + _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() + kwargs: dict[str, Any] = {"clean_logits": clean_logits} + if logits_dtype is not torch.float32: + kwargs["logits_dtype"] = logits_dtype + if indices is not None: + kwargs["indices"] = indices return _fp8_fp4_paged_mqa_logits_impl( q, kv_cache, @@ -453,7 +518,7 @@ def fp8_fp4_paged_mqa_logits( block_tables, schedule_metadata, max_model_len, - clean_logits=clean_logits, + **kwargs, ) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2870ec9a15c0..e58a2cde7bb2 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -256,25 +256,32 @@ def __init__(self, *args, **kwargs): self.vllm_config.attention_config.use_fp4_indexer_cache ) - assert ( - current_platform.is_device_capability_family(100) - or not self.use_fp4_indexer_cache - ), ( + supports_fp4_indexer_cache = current_platform.is_device_capability_family( + 100 + ) or (current_platform.is_rocm() and current_platform.supports_mx()) + assert supports_fp4_indexer_cache or not self.use_fp4_indexer_cache, ( "use_fp4_indexer_cache requires Blackwell datacenter GPUs " - "(sm_10x, e.g. B200/GB200); sm_120 (consumer Blackwell) and " - "earlier architectures are not supported." + "(sm_10x, e.g. 
B200/GB200) or ROCm GPUs with MX support " + "(gfx95x, e.g. MI355X); sm_120, earlier CUDA architectures, " + "and non-MX ROCm architectures are not supported." ) next_n = self.num_speculative_tokens + 1 self.reorder_batch_threshold += self.num_speculative_tokens + native_paged_mqa_logits = current_platform.is_device_capability_family( + 100 + ) or ( + current_platform.is_rocm() + and current_platform.supports_mx() + and self.use_fp4_indexer_cache + ) # NOTE(zyongye) fp4 indexer cache only natively supports next_n in # natively_supported_next_n_fp4; for other next_n values we fall back - # to the flattening path. Outside the SM100 datacenter family the FP8 + # to the flattening path. Outside native FP4 indexer platforms the FP8 # paged MQA logits kernel has the same [1, 2] constraint (deepgemm # smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too. self.use_flattening = ( - self.use_fp4_indexer_cache - or not current_platform.is_device_capability_family(100) + self.use_fp4_indexer_cache or not native_paged_mqa_logits ) and next_n not in self.natively_supported_next_n_fp4 sm_count = num_compute_units(self.device.index) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index ec880f7ab4c4..ad2ee75476ce 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -3,6 +3,7 @@ import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_cutedsl @@ -25,23 +26,54 @@ def _get_cos_sin( return cos, sin -@triton.jit -def _fp32x2_to_fp4x2(x_lo, x_hi): - # NOTE: $1 is high nibble, $2 is low nibble - return tl.inline_asm_elementwise( - """ - { - .reg .b8 tmp; - cvt.rn.satfinite.e2m1x2.f32 tmp, $1, $2; - cvt.u32.u8 $0, tmp; - } - """, - constraints="=r,f,f", - args=[x_hi, x_lo], - dtype=tl.uint32, - is_pure=True, - pack=1, - ).to(tl.uint8) +if 
current_platform.is_rocm(): + + @triton.jit + def _fp32x2_to_fp4x2(x_lo, x_hi): + lo_abs = tl.abs(x_lo) + hi_abs = tl.abs(x_hi) + + lo = tl.full(x_lo.shape, 0, tl.uint32) + lo = tl.where(lo_abs > 0.25, 1, lo) + lo = tl.where(lo_abs >= 0.75, 2, lo) + lo = tl.where(lo_abs > 1.25, 3, lo) + lo = tl.where(lo_abs >= 1.75, 4, lo) + lo = tl.where(lo_abs > 2.5, 5, lo) + lo = tl.where(lo_abs >= 3.5, 6, lo) + lo = tl.where(lo_abs > 5.0, 7, lo) + lo = lo | ((x_lo < 0.0).to(tl.uint32) << 3) + + hi = tl.full(x_hi.shape, 0, tl.uint32) + hi = tl.where(hi_abs > 0.25, 1, hi) + hi = tl.where(hi_abs >= 0.75, 2, hi) + hi = tl.where(hi_abs > 1.25, 3, hi) + hi = tl.where(hi_abs >= 1.75, 4, hi) + hi = tl.where(hi_abs > 2.5, 5, hi) + hi = tl.where(hi_abs >= 3.5, 6, hi) + hi = tl.where(hi_abs > 5.0, 7, hi) + hi = hi | ((x_hi < 0.0).to(tl.uint32) << 3) + + return (lo | (hi << 4)).to(tl.uint8) + +else: + + @triton.jit + def _fp32x2_to_fp4x2(x_lo, x_hi): + # NOTE: $1 is high nibble, $2 is low nibble + return tl.inline_asm_elementwise( + """ + { + .reg .b8 tmp; + cvt.rn.satfinite.e2m1x2.f32 tmp, $1, $2; + cvt.u32.u8 $0, tmp; + } + """, + constraints="=r,f,f", + args=[x_hi, x_lo], + dtype=tl.uint32, + is_pure=True, + pack=1, + ).to(tl.uint8) @triton.jit diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 8bda5bb7a86b..892095c0cd93 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -12,9 +12,13 @@ from vllm.forward_context import get_forward_context from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import fp8_fp4_paged_mqa_logits from vllm.utils.torch_utils import LayerNameType from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import ( + 
rocm_fp4_mqa_logits, +) if current_platform.is_rocm(): from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950 @@ -22,6 +26,8 @@ _ON_GFX942 = False _ON_GFX950 = False +MXFP4_BLOCK_SIZE = 32 + @triton.jit def _indexer_k_quant_and_cache_kernel( @@ -226,6 +232,75 @@ def cp_gather_indexer_k_quant_cache_triton( ) +@triton.jit +def _cp_gather_indexer_mxfp4_cache_kernel( + kv_cache_ptr, + k_fp4_ptr, + k_scale_ptr, + block_table_ptr, + cu_seqlen_ptr, + token_to_seq_ptr, + block_size, + block_table_stride, + kv_cache_stride, + HEAD_BYTES: tl.constexpr, + SCALE_BYTES: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + tid = tl.program_id(0) + offsets = tl.arange(0, BLOCK_HEAD) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + if tid >= batch_end: + return + + batch_offset = tid - batch_start + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_id = tl.load(block_table_ptr + batch_id * block_table_stride + block_table_id) + block_base = kv_cache_ptr + block_id.to(tl.int64) * kv_cache_stride + + src_values = block_base + block_offset * HEAD_BYTES + dst_values = k_fp4_ptr + tid * HEAD_BYTES + value_mask = offsets < HEAD_BYTES + values = tl.load(src_values + offsets, mask=value_mask, other=0) + tl.store(dst_values + offsets, values, mask=value_mask) + + scale_offsets = tl.arange(0, SCALE_BYTES) + src_scales = block_base + block_size * HEAD_BYTES + block_offset * SCALE_BYTES + dst_scales = k_scale_ptr + tid * SCALE_BYTES + scales = tl.load(src_scales + scale_offsets) + tl.store(dst_scales + scale_offsets, scales) + + +def cp_gather_indexer_mxfp4_cache_triton( + k_cache: torch.Tensor, + k_fp4: torch.Tensor, + k_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, +): + head_bytes = k_fp4.shape[-1] + scale_bytes = k_scale.shape[-1] + num_tokens = k_fp4.size(0) + 
_cp_gather_indexer_mxfp4_cache_kernel[(num_tokens,)]( + k_cache, + k_fp4, + k_scale, + block_table, + cu_seqlen, + token_to_seq, + k_cache.shape[1], + block_table.stride(0), + k_cache.stride(0), + HEAD_BYTES=head_bytes, + SCALE_BYTES=scale_bytes, + BLOCK_HEAD=triton.next_power_of_2(head_bytes), + ) + + # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 def fp8_paged_mqa_logits_torch( q: torch.Tensor, @@ -573,11 +648,73 @@ def _topk_indices_torch( return padded +def _has_vllm_topk_ops() -> bool: + return ( + hasattr(torch.ops, "_C") + and hasattr(torch.ops._C, "top_k_per_row_prefill") + and hasattr(torch.ops._C, "top_k_per_row_decode") + ) + + +def _topk_indices_vllm_prefill( + logits: torch.Tensor, + topk_tokens: int, + row_starts: torch.Tensor, + row_ends: torch.Tensor, +) -> torch.Tensor: + if not _has_vllm_topk_ops(): + return _topk_indices_torch(logits, topk_tokens, row_starts) + + indices = torch.empty( + (logits.shape[0], topk_tokens), + dtype=torch.int32, + device=logits.device, + ) + torch.ops._C.top_k_per_row_prefill( + logits, + row_starts, + row_ends, + indices, + logits.shape[0], + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + return indices + + +def _topk_indices_vllm_decode( + logits: torch.Tensor, + topk_tokens: int, + next_n: int, + seq_lens: torch.Tensor, +) -> torch.Tensor: + if not _has_vllm_topk_ops(): + return _topk_indices_torch(logits, topk_tokens) + + indices = torch.empty( + (logits.shape[0], topk_tokens), + dtype=torch.int32, + device=logits.device, + ) + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + indices, + logits.shape[0], + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + return indices + + def rocm_aiter_sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, weights: torch.Tensor, 
quant_block_size: int, @@ -587,17 +724,42 @@ def rocm_aiter_sparse_attn_indexer_fake( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, - skip_k_cache_insert: bool = False, + use_fp4_cache: bool = False, ) -> torch.Tensor: + del ( + k_cache_prefix, + kv_cache, + q_quant, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + max_model_len, + ) + device = hidden_states.device if k is None else k.device + flattened_width = ( + head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE + if use_fp4_cache + else head_dim + 4 + ) + _flattened_kv = torch.empty( + [total_seq_lens, flattened_width], device=device, dtype=torch.uint8 + ) + if use_fp4_cache: + _k_fp4 = _flattened_kv[..., : head_dim // 2].contiguous() + _k_scale = _flattened_kv[..., head_dim // 2 :].contiguous() + else: + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer -@eager_break_during_capture -def rocm_aiter_sparse_attn_indexer( +def rocm_aiter_sparse_attn_indexer_native( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, weights: torch.Tensor, quant_block_size: int, @@ -608,6 +770,7 @@ def rocm_aiter_sparse_attn_indexer( total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, skip_k_cache_insert: bool = False, + use_fp4_cache: bool = False, ) -> torch.Tensor: # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata @@ -622,7 +785,7 @@ def rocm_aiter_sparse_attn_indexer( hidden_states, k_cache_prefix, kv_cache, - q_fp8, + q_quant, k, weights, quant_block_size, @@ -632,7 +795,7 @@ def rocm_aiter_sparse_attn_indexer( max_model_len, total_seq_lens, topk_indices_buffer, - skip_k_cache_insert, + use_fp4_cache=use_fp4_cache, ) layer_attn_metadata = attn_metadata[k_cache_prefix] assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata) @@ -643,6 +806,17 @@ def rocm_aiter_sparse_attn_indexer( has_prefill = layer_attn_metadata.num_prefills > 0 num_decode_tokens = layer_attn_metadata.num_decode_tokens device = hidden_states.device if k is None else k.device + if use_fp4_cache: + assert isinstance(q_quant, tuple), ( + "MXFP4 sparse_attn_indexer expects (q_values, q_scales)" + ) + q_values, q_scale = q_quant + assert q_scale.dtype == torch.int32 + else: + assert isinstance(q_quant, torch.Tensor), ( + "FP8 sparse_attn_indexer expects a single q tensor" + ) + q_values, q_scale = q_quant, None # during speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. 
@@ -653,6 +827,7 @@ def rocm_aiter_sparse_attn_indexer( raise ValueError("k must be provided when skip_k_cache_insert is False") if not skip_k_cache_insert: + assert not use_fp4_cache, "Unfused FP4 indexer cache insert is unsupported" if _ON_GFX942: ops.indexer_k_quant_and_cache( k, @@ -674,21 +849,48 @@ def rocm_aiter_sparse_attn_indexer( if has_prefill: prefill_metadata = layer_attn_metadata.prefill assert prefill_metadata is not None - for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], + max_total_seq_lens = max( + (chunk.total_seq_lens for chunk in prefill_metadata.chunks), + default=0, + ) + if use_fp4_cache: + k_quant_full = torch.zeros( + [max_total_seq_lens, head_dim // 2], + device=device, + dtype=torch.uint8, + ) + k_scale_full = torch.zeros( + [max_total_seq_lens, head_dim // MXFP4_BLOCK_SIZE], + device=device, + dtype=torch.uint8, + ) + else: + k_quant_full = torch.zeros( + [max_total_seq_lens, head_dim], device=device, dtype=fp8_dtype, ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], + k_scale_full = torch.zeros( + [max_total_seq_lens, 4], device=device, dtype=torch.uint8, ) - if _ON_GFX942: + for chunk in prefill_metadata.chunks: + k_quant = k_quant_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] + if use_fp4_cache: + cp_gather_indexer_mxfp4_cache_triton( + kv_cache, + k_quant, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + token_to_seq=chunk.token_to_seq, + ) + elif _ON_GFX942: ops.cp_gather_indexer_k_quant_cache( kv_cache, - k_fp8, + k_quant, k_scale, chunk.block_table, chunk.cu_seq_lens, @@ -696,91 +898,170 @@ def rocm_aiter_sparse_attn_indexer( else: cp_gather_indexer_k_quant_cache_triton( kv_cache, - k_fp8, + k_quant, k_scale, chunk.block_table, chunk.cu_seq_lens, token_to_seq=chunk.token_to_seq, ) - logits = rocm_fp8_mqa_logits( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32)), - weights[chunk.token_start : 
chunk.token_end], - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - ) + q_slice = q_values[chunk.token_start : chunk.token_end] + weight_slice = weights[chunk.token_start : chunk.token_end] + if use_fp4_cache: + assert q_scale is not None + use_vllm_topk = _has_vllm_topk_ops() + q_scale_slice = q_scale[chunk.token_start : chunk.token_end] + logits = rocm_fp4_mqa_logits( + (q_slice.view(torch.int8), q_scale_slice), + ( + k_quant.view(torch.int8), + k_scale.view(torch.int32).squeeze(-1), + ), + weight_slice, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + clean_logits=not use_vllm_topk, + ) + else: + logits = rocm_fp8_mqa_logits( + q_slice, + (k_quant, k_scale.view(torch.float32)), + weight_slice, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - - num_rows = logits.shape[0] - - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + if use_fp4_cache: + topk_indices.copy_( + _topk_indices_vllm_prefill( + logits, + topk_tokens, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + ) + else: + topk_indices.copy_( + _topk_indices_torch(logits, topk_tokens, chunk.cu_seqlen_ks) + ) if has_decode: decode_metadata = layer_attn_metadata.decode assert decode_metadata is not None - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) + if use_fp4_cache: + num_blocks, block_size, _ = kv_cache.shape + page_bytes = int(kv_cache.stride(0)) + fp4_bytes = head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE + kv_cache_decode = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, fp4_bytes), + stride=(page_bytes, fp4_bytes, fp4_bytes, 1), + ) + else: + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + 
kv_cache_decode = kv_cache.unsqueeze(-2) decode_lens = decode_metadata.decode_lens if decode_metadata.requires_padding: # pad in edge case where we have short chunked prefill length < # decode_threshold since we unstrictly split # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) + if use_fp4_cache: + padded_q_decode_tokens = pack_seq_triton( + q_values[:num_decode_tokens], + decode_lens, + pad_value=0, + ) + assert q_scale is not None + q_scale_bytes = q_scale[:num_decode_tokens].contiguous() + q_scale_bytes = q_scale_bytes.view(torch.uint8).reshape( + num_decode_tokens, -1 + ) + padded_q_scale = pack_seq_triton( + q_scale_bytes, + decode_lens, + pad_value=0, + ) + padded_q_scale = padded_q_scale.view(torch.int32).reshape( + padded_q_scale.shape[0], + padded_q_scale.shape[1], + -1, + ) + else: + padded_q_decode_tokens = pack_seq_triton( + q_values[:num_decode_tokens], decode_lens + ) + padded_q_scale = None else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] + padded_q_decode_tokens = q_values[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_values.shape[1:] ) + if use_fp4_cache: + assert q_scale is not None + padded_q_scale = q_scale[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_scale.shape[1:] + ) + else: + padded_q_scale = None # TODO: move and optimize below logic with triton kernels - batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] + batch_size = padded_q_decode_tokens.shape[0] + next_n = padded_q_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n - logits = rocm_fp8_paged_mqa_logits( - padded_q_fp8_decode_tokens, - kv_cache, - weights[:num_padded_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - 
decode_metadata.schedule_metadata, - max_model_len=max_model_len, - ) + if use_fp4_cache: + use_vllm_topk = _has_vllm_topk_ops() + active_paged_width = ( + decode_metadata.block_table.shape[1] * kv_cache_decode.shape[1] + ) + logits_width = min( + max_model_len, + max(topk_tokens, active_paged_width), + ) + logits = fp8_fp4_paged_mqa_logits( + (padded_q_decode_tokens.view(torch.int8), padded_q_scale), + kv_cache_decode, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=not use_vllm_topk, + ) + else: + logits = rocm_fp8_paged_mqa_logits( + padded_q_decode_tokens, + kv_cache_decode, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - num_rows = logits.shape[0] - - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + if use_fp4_cache: + topk_indices.copy_( + _topk_indices_vllm_decode( + logits, + topk_tokens, + next_n, + decode_metadata.seq_lens, + )[:num_decode_tokens] + ) + else: + topk_indices.copy_( + _topk_indices_torch(logits, topk_tokens)[:num_decode_tokens] + ) if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( - topk_indices.reshape(batch_size, next_n, topk_indices.shape[-1]), + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens, ) topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( @@ -790,6 +1071,41 @@ def rocm_aiter_sparse_attn_indexer( return topk_indices_buffer +@eager_break_during_capture +def rocm_aiter_sparse_attn_indexer( + 
hidden_states: torch.Tensor, + k_cache_prefix: LayerNameType, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + skip_k_cache_insert=False, + use_fp4_cache=False, + ) + + def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor: if scale.dtype == torch.float8_e8m0fnu: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( diff --git a/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py new file mode 100644 index 000000000000..b5e60f488501 --- /dev/null +++ b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ROCm Triton implementation of DeepGEMM's paged MQA logits API. + +The public wrapper mirrors ``deep_gemm.fp8_fp4_paged_mqa_logits`` for the +layouts used by DeepSeek V4's sparse indexer: + +* FP8 Q + FP8 K-cache: K values are stored first for the whole page, followed + by one float32 scale per token. +* MXFP4 Q + MXFP4 K-cache: packed E2M1 values are stored first for the whole + page, followed by one int32 per token containing four UE8M0 scale bytes. + +DeepGEMM's CUDA scheduler metadata is accepted for API compatibility but is not +needed by this Triton implementation. 
+""" + +from __future__ import annotations + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fp4_e2m1_to_f32(nibble): + mag = nibble & 0x7 + val = tl.where( + mag == 0, + 0.0, + tl.where( + mag == 1, + 0.5, + tl.where( + mag == 2, + 1.0, + tl.where( + mag == 3, + 1.5, + tl.where( + mag == 4, + 2.0, + tl.where(mag == 5, 3.0, tl.where(mag == 6, 4.0, 6.0)), + ), + ), + ), + ), + ) + return tl.where((nibble & 0x8) != 0, -val, val) + + +@triton.jit +def _ue8m0_scale_from_packed_i32(scale_word, block_idx): + shift = block_idx * 8 + encoded = (scale_word.to(tl.uint32) >> shift) & 0xFF + return tl.exp2(encoded.to(tl.float32) - 127.0) + + +@triton.jit +def _ue8m0_scale_byte_from_packed_i32(scale_word, block_idx): + shift = block_idx * 8 + return ((scale_word.to(tl.uint32) >> shift) & 0xFF).to(tl.uint8) + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_cache_ptr, + weights_ptr, + context_lens_ptr, + block_table_ptr, + logits_ptr, + q_stride_b, + q_stride_n, + q_stride_h, + q_stride_d, + weights_stride_row, + weights_stride_h, + context_stride_b, + context_stride_n, + block_table_stride_b, + logits_stride_row, + max_context_len, + KV_CACHE_STRIDE_B: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + NEXT_N: tl.constexpr, + NUM_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_KV: tl.constexpr, + IS_CONTEXT_LENS_2D: tl.constexpr, + IS_FNUZ: tl.constexpr, +): + row = tl.program_id(0) + logical_block = tl.program_id(1) + + batch = row // NEXT_N + next_idx = row - batch * NEXT_N + if IS_CONTEXT_LENS_2D: + context_len = tl.load( + context_lens_ptr + batch * context_stride_b + next_idx * context_stride_n + ) + valid_limit = context_len + else: + context_len = tl.load(context_lens_ptr + batch * context_stride_b) + valid_limit = context_len - NEXT_N + next_idx + 1 + valid_limit = tl.minimum(valid_limit, max_context_len) + + tile_start = logical_block * KV_BLOCK_SIZE + if tile_start >= 
valid_limit: + return + + physical_block = tl.load( + block_table_ptr + batch * block_table_stride_b + logical_block + ) + block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B + scale_base = (block_base + KV_BLOCK_SIZE * HEAD_DIM).to( + tl.pointer_type(tl.float32) + ) + + h = tl.arange(0, NUM_HEADS) + d = tl.arange(0, HEAD_DIM) + k_offsets = tl.arange(0, BLOCK_KV) + token_pos = k_offsets + global_k = tile_start + k_offsets + valid_k = (k_offsets < KV_BLOCK_SIZE) & (global_k < valid_limit) + + q = tl.load( + q_ptr + + batch * q_stride_b + + next_idx * q_stride_n + + h[:, None] * q_stride_h + + d[None, :] * q_stride_d, + cache_modifier=".cg", + ) + + k_u8 = tl.load( + block_base + token_pos[None, :] * HEAD_DIM + d[:, None], + mask=(k_offsets[None, :] < KV_BLOCK_SIZE), + other=0, + ) + if IS_FNUZ: + k = k_u8.to(tl.float8e4b15, bitcast=True) + else: + k = k_u8.to(tl.float8e4nv, bitcast=True) + + scores = tl.dot(q, k, input_precision="ieee") + k_scales = tl.load(scale_base + token_pos, mask=valid_k, other=0.0) + scores = scores * k_scales[None, :] + scores = tl.maximum(scores, 0.0) + + weights = tl.load( + weights_ptr + row * weights_stride_row + h * weights_stride_h, + cache_modifier=".cg", + ).to(tl.float32) + scores = scores * weights[:, None] + logits = tl.sum(scores, axis=0) + tl.store( + logits_ptr + row * logits_stride_row + global_k, + logits, + mask=valid_k, + ) + + +@triton.jit +def _fp4_paged_mqa_logits_kernel( + q_ptr, + q_scale_ptr, + kv_cache_ptr, + weights_ptr, + context_lens_ptr, + block_table_ptr, + logits_ptr, + q_stride_b, + q_stride_n, + q_stride_h, + q_stride_d, + q_scale_stride_b, + q_scale_stride_n, + q_scale_stride_h, + weights_stride_row, + weights_stride_h, + context_stride_b, + context_stride_n, + block_table_stride_b, + logits_stride_row, + max_context_len, + KV_CACHE_STRIDE_B: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + NEXT_N: tl.constexpr, + NUM_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_KV: 
tl.constexpr, + IS_CONTEXT_LENS_2D: tl.constexpr, +): + row = tl.program_id(0) + logical_block = tl.program_id(1) + + batch = row // NEXT_N + next_idx = row - batch * NEXT_N + if IS_CONTEXT_LENS_2D: + context_len = tl.load( + context_lens_ptr + batch * context_stride_b + next_idx * context_stride_n + ) + valid_limit = context_len + else: + context_len = tl.load(context_lens_ptr + batch * context_stride_b) + valid_limit = context_len - NEXT_N + next_idx + 1 + valid_limit = tl.minimum(valid_limit, max_context_len) + + tile_start = logical_block * KV_BLOCK_SIZE + if tile_start >= valid_limit: + return + + physical_block = tl.load( + block_table_ptr + batch * block_table_stride_b + logical_block + ) + value_dim: tl.constexpr = HEAD_DIM // 2 + scale_dim: tl.constexpr = HEAD_DIM // 32 + block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B + scale_base = (block_base + KV_BLOCK_SIZE * value_dim).to( + tl.pointer_type(tl.int32) + ) + + h = tl.arange(0, NUM_HEADS) + d_byte = tl.arange(0, value_dim) + scale_idx = tl.arange(0, scale_dim) + k_offsets = tl.arange(0, BLOCK_KV) + token_pos = k_offsets + global_k = tile_start + k_offsets + valid_k = (k_offsets < KV_BLOCK_SIZE) & (global_k < valid_limit) + + q_packed = tl.load( + q_ptr + + batch * q_stride_b + + next_idx * q_stride_n + + h[:, None] * q_stride_h + + d_byte[None, :] * q_stride_d, + cache_modifier=".cg", + ).to(tl.uint8) + q_scale_word = tl.load( + q_scale_ptr + + batch * q_scale_stride_b + + next_idx * q_scale_stride_n + + h * q_scale_stride_h, + cache_modifier=".cg", + ) + q_scale = _ue8m0_scale_byte_from_packed_i32( + q_scale_word[:, None], scale_idx[None, :] + ) + + k_packed = tl.load( + block_base + token_pos[None, :] * value_dim + d_byte[:, None], + mask=(k_offsets[None, :] < KV_BLOCK_SIZE), + other=0, + ).to(tl.uint8) + k_scale_word = tl.load(scale_base + token_pos, mask=valid_k, other=0) + k_scale = _ue8m0_scale_byte_from_packed_i32( + k_scale_word[:, None], scale_idx[None, :] + ) + + 
scores = tl.dot_scaled( + q_packed, + q_scale, + "e2m1", + k_packed, + k_scale, + "e2m1", + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=tl.float32, + ) + scores = tl.maximum(scores, 0.0) + + weights = tl.load( + weights_ptr + row * weights_stride_row + h * weights_stride_h, + cache_modifier=".cg", + ).to(tl.float32) + scores = scores * weights[:, None] + logits = tl.sum(scores, axis=0) + tl.store( + logits_ptr + row * logits_stride_row + global_k, + logits, + mask=valid_k, + ) + + +def rocm_get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, + block_size: int, + num_sms: int, + indices: torch.Tensor | None = None, +) -> torch.Tensor: + """Return a DeepGEMM-compatible metadata tensor for ROCm. + + The Triton kernels below schedule directly from ``context_lens`` and + ``block_table``, so the contents are intentionally unused. Returning the + same ``[num_sms + 1, 2]`` int32 shape keeps callers compatible with + DeepGEMM's API and shape checks. + """ + del block_size, indices + return torch.empty( + (int(num_sms) + 1, 2), dtype=torch.int32, device=context_lens.device + ) + + +@triton.jit +def _fp4_mqa_logits_kernel( + q_ptr, + q_scale_ptr, + k_ptr, + k_scale_ptr, + weights_ptr, + cu_start_ptr, + cu_end_ptr, + logits_ptr, + seq_len_kv, + q_stride_s, + q_stride_h, + q_stride_d, + q_scale_stride_s, + q_scale_stride_h, + k_stride_s, + k_stride_d, + weights_stride_s, + weights_stride_h, + logits_stride_s, + NUM_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_KV: tl.constexpr, + USE_DOT_SCALED: tl.constexpr, +): + row = tl.program_id(0) + block_idx = tl.program_id(1) + + start = tl.load(cu_start_ptr + row) + end = tl.load(cu_end_ptr + row) + start = tl.maximum(start, 0) + end = tl.minimum(end, seq_len_kv) + + tile_start = start + block_idx * BLOCK_KV + if tile_start >= end: + return + + k_offsets = tile_start + tl.arange(0, BLOCK_KV) + valid_k = k_offsets < end + + h = tl.arange(0, NUM_HEADS) + if USE_DOT_SCALED: + value_dim: tl.constexpr = HEAD_DIM // 2 
+ scale_dim: tl.constexpr = HEAD_DIM // 32 + d_byte = tl.arange(0, value_dim) + scale_idx = tl.arange(0, scale_dim) + + q_packed = tl.load( + q_ptr + + row * q_stride_s + + h[:, None] * q_stride_h + + d_byte[None, :] * q_stride_d, + cache_modifier=".cg", + ).to(tl.uint8) + q_scale_word = tl.load( + q_scale_ptr + row * q_scale_stride_s + h * q_scale_stride_h, + cache_modifier=".cg", + ) + q_scale = _ue8m0_scale_byte_from_packed_i32( + q_scale_word[:, None], scale_idx[None, :] + ) + + k_packed = tl.load( + k_ptr + k_offsets[None, :] * k_stride_s + d_byte[:, None] * k_stride_d, + mask=valid_k[None, :], + other=0, + ).to(tl.uint8) + k_scale_word = tl.load(k_scale_ptr + k_offsets, mask=valid_k, other=0) + k_scale = _ue8m0_scale_byte_from_packed_i32( + k_scale_word[:, None], scale_idx[None, :] + ) + + scores = tl.dot_scaled( + q_packed, + q_scale, + "e2m1", + k_packed, + k_scale, + "e2m1", + lhs_k_pack=True, + rhs_k_pack=True, + out_dtype=tl.float32, + ) + else: + d = tl.arange(0, HEAD_DIM) + d_byte = d // 2 + d_is_odd = (d & 1) != 0 + d_scale_block = d // 32 + + q_packed = tl.load( + q_ptr + + row * q_stride_s + + h[:, None] * q_stride_h + + d_byte[None, :] * q_stride_d, + cache_modifier=".cg", + ).to(tl.uint32) + q_nibble = tl.where(d_is_odd[None, :], q_packed >> 4, q_packed) & 0xF + q_values = _fp4_e2m1_to_f32(q_nibble) + q_scale_word = tl.load( + q_scale_ptr + row * q_scale_stride_s + h * q_scale_stride_h, + cache_modifier=".cg", + ) + q_scale = _ue8m0_scale_from_packed_i32( + q_scale_word[:, None], d_scale_block[None, :] + ) + q_values = q_values * q_scale + + k_packed = tl.load( + k_ptr + k_offsets[None, :] * k_stride_s + d_byte[:, None] * k_stride_d, + mask=valid_k[None, :], + other=0, + ).to(tl.uint32) + k_nibble = tl.where(d_is_odd[:, None], k_packed >> 4, k_packed) & 0xF + k_values = _fp4_e2m1_to_f32(k_nibble) + k_scale_word = tl.load(k_scale_ptr + k_offsets, mask=valid_k, other=0) + k_scale = _ue8m0_scale_from_packed_i32( + k_scale_word[None, :], 
d_scale_block[:, None] + ) + k_values = k_values * k_scale + scores = tl.dot(q_values, k_values, input_precision="ieee") + scores = tl.maximum(scores, 0.0) + + weights = tl.load( + weights_ptr + row * weights_stride_s + h * weights_stride_h, + cache_modifier=".cg", + ).to(tl.float32) + logits = tl.sum(scores * weights[:, None], axis=0) + tl.store( + logits_ptr + row * logits_stride_s + k_offsets, + logits, + mask=valid_k, + ) + + +def _fp8_is_fnuz() -> bool: + fnuz_dtype = getattr(torch, "float8_e4m3fnuz", None) + return fnuz_dtype is not None and current_platform.fp8_dtype() == fnuz_dtype + + +def _block_kv_for_dot(block_size: int) -> int: + return max(16, triton.next_power_of_2(block_size)) + + +def _num_logical_blocks_for_launch( + max_context_len: int, + block_size: int, + block_tables: torch.Tensor, +) -> int: + max_blocks_from_context = triton.cdiv(max_context_len, block_size) + return min(max_blocks_from_context, int(block_tables.shape[1])) + + +def rocm_fp4_mqa_logits( + q: tuple[torch.Tensor, torch.Tensor], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool = False, +) -> torch.Tensor: + """Compute non-paged MXFP4 MQA logits for ROCm sparse-indexer prefill.""" + q_values, q_scales = q + k_values, k_scales = kv + seq_len, num_heads, packed_dim = q_values.shape + seq_len_kv = k_values.shape[0] + head_dim = packed_dim * 2 + assert head_dim == 128, f"MXFP4 MQA logits expects head_dim=128, got {head_dim}" + assert q_scales.shape == (seq_len, num_heads) + assert k_values.shape == (seq_len_kv, packed_dim) + assert k_scales.shape == (seq_len_kv,) + assert q_scales.dtype == torch.int32 + assert k_scales.dtype == torch.int32 + assert weights.shape == (seq_len, num_heads) + assert weights.dtype == torch.float32 + + logits_shape = (seq_len, seq_len_kv) + if clean_logits: + logits = torch.full( + logits_shape, + float("-inf"), + dtype=torch.float32, + device=q_values.device, 
+ ) + else: + logits = torch.empty( + logits_shape, + dtype=torch.float32, + device=q_values.device, + ) + + valid_lens = (cu_seqlen_ke - cu_seqlen_ks).clamp(min=0) + max_valid_len = int(valid_lens.max().item()) if seq_len > 0 else 0 + if max_valid_len == 0 or seq_len_kv == 0: + return logits + + block_kv = 64 + grid = (seq_len, triton.cdiv(max_valid_len, block_kv)) + _fp4_mqa_logits_kernel[grid]( + q_values, + q_scales, + k_values, + k_scales, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + seq_len_kv, + q_values.stride(0), + q_values.stride(1), + q_values.stride(2), + q_scales.stride(0), + q_scales.stride(1), + k_values.stride(0), + k_values.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + NUM_HEADS=num_heads, + HEAD_DIM=head_dim, + BLOCK_KV=block_kv, + USE_DOT_SCALED=num_heads % 32 == 0, + num_warps=4, + ) + return logits + + +def rocm_fp8_fp4_paged_mqa_logits( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_context_len: int, + clean_logits: bool = False, + logits_dtype: torch.dtype = torch.float32, + indices: torch.Tensor | None = None, +) -> torch.Tensor: + """Compute DeepGEMM-compatible paged MQA logits on ROCm.""" + del schedule_metadata, indices + + q_values, q_scales = q + is_fp4 = q_scales is not None + assert kv_cache.dtype == torch.uint8, ( + f"kv_cache must be uint8, got {kv_cache.dtype}" + ) + assert context_lens.dtype == torch.int32, ( + f"context_lens must be int32, got {context_lens.dtype}" + ) + assert block_tables.dtype == torch.int32, ( + f"block_tables must be int32, got {block_tables.dtype}" + ) + assert weights.dtype == torch.float32, ( + f"weights must be float32, got {weights.dtype}" + ) + assert context_lens.dim() in (1, 2), ( + f"context_lens must be 1D or 2D, got {context_lens.shape}" + ) + + if is_fp4: + assert q_scales is not None + batch_size, next_n, 
num_heads, packed_dim = q_values.shape + head_dim = packed_dim * 2 + assert q_scales.dtype == torch.int32, ( + f"MXFP4 q scales must be int32, got {q_scales.dtype}" + ) + assert q_scales.shape == (batch_size, next_n, num_heads), ( + f"Expected q scales {(batch_size, next_n, num_heads)}, got {q_scales.shape}" + ) + else: + batch_size, next_n, num_heads, head_dim = q_values.shape + + assert head_dim in (32, 64, 128), f"unsupported head_dim={head_dim}" + assert num_heads in (32, 64), f"unsupported num_heads={num_heads}" + assert weights.shape == (batch_size * next_n, num_heads), ( + f"Expected weights {(batch_size * next_n, num_heads)}, got {weights.shape}" + ) + if context_lens.dim() == 2: + assert context_lens.shape == (batch_size, next_n), ( + f"Expected context_lens {(batch_size, next_n)}, got {context_lens.shape}" + ) + else: + assert context_lens.shape == (batch_size,), ( + f"Expected context_lens {(batch_size,)}, got {context_lens.shape}" + ) + assert block_tables.shape[0] == batch_size, ( + f"Expected {batch_size} block-table rows, got {block_tables.shape[0]}" + ) + assert kv_cache.dim() == 4 and kv_cache.shape[2] == 1, ( + "kv_cache must have shape [num_blocks, block_size, 1, width], got " + f"{kv_cache.shape}" + ) + + block_size = int(kv_cache.shape[1]) + expected_cache_width = head_dim // 2 + 4 if is_fp4 else head_dim + 4 + assert kv_cache.shape[3] == expected_cache_width, ( + f"Expected kv_cache width {expected_cache_width}, got {kv_cache.shape[3]}" + ) + assert logits_dtype in (torch.float32, torch.bfloat16), ( + f"logits_dtype must be float32 or bfloat16, got {logits_dtype}" + ) + + logits = torch.empty( + (batch_size * next_n, max_context_len), + dtype=logits_dtype, + device=q_values.device, + ) + if clean_logits: + logits.fill_(float("-inf")) + if max_context_len == 0: + return logits + + num_logical_blocks = _num_logical_blocks_for_launch( + max_context_len, block_size, block_tables + ) + if num_logical_blocks == 0: + return logits + + grid = 
(batch_size * next_n, num_logical_blocks) + block_kv = _block_kv_for_dot(block_size) + is_context_lens_2d = context_lens.dim() == 2 + + if is_fp4: + assert q_scales is not None + _fp4_paged_mqa_logits_kernel[grid]( + q_values, + q_scales, + kv_cache, + weights, + context_lens, + block_tables, + logits, + q_values.stride(0), + q_values.stride(1), + q_values.stride(2), + q_values.stride(3), + q_scales.stride(0), + q_scales.stride(1), + q_scales.stride(2), + weights.stride(0), + weights.stride(1), + context_lens.stride(0), + context_lens.stride(1) if is_context_lens_2d else 0, + block_tables.stride(0), + logits.stride(0), + max_context_len, + KV_CACHE_STRIDE_B=kv_cache.stride(0), + KV_BLOCK_SIZE=block_size, + NEXT_N=next_n, + NUM_HEADS=num_heads, + HEAD_DIM=head_dim, + BLOCK_KV=block_kv, + IS_CONTEXT_LENS_2D=is_context_lens_2d, + num_warps=4, + ) + else: + _fp8_paged_mqa_logits_kernel[grid]( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + logits, + q_values.stride(0), + q_values.stride(1), + q_values.stride(2), + q_values.stride(3), + weights.stride(0), + weights.stride(1), + context_lens.stride(0), + context_lens.stride(1) if is_context_lens_2d else 0, + block_tables.stride(0), + logits.stride(0), + max_context_len, + KV_CACHE_STRIDE_B=kv_cache.stride(0), + KV_BLOCK_SIZE=block_size, + NEXT_N=next_n, + NUM_HEADS=num_heads, + HEAD_DIM=head_dim, + BLOCK_KV=block_kv, + IS_CONTEXT_LENS_2D=is_context_lens_2d, + IS_FNUZ=_fp8_is_fnuz(), + num_warps=4, + ) + return logits From fc065f3109d9ecde18f31ba69b96d7277c1775ca Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sat, 16 May 2026 23:42:31 -0500 Subject: [PATCH 5/9] optimized fp4indexer Signed-off-by: tjtanaa --- .../layers/deepseek_v4_attention.py | 104 ++++- vllm/v1/attention/backends/mla/indexer.py | 12 + .../mla/rocm_aiter_mla_sparse_dsv4.py | 328 +++++++++++++-- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 378 +++++++++++++----- .../ops/rocm_fp8_fp4_paged_mqa_logits.py | 45 ++- 5 files changed, 707 
insertions(+), 160 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 32d7907cc7fe..b16355f9fc4d 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -350,7 +350,11 @@ def forward( return self.wo_b(z.flatten(1)) - def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: + def attn_gemm_parallel_execute( + self, + hidden_states: torch.Tensor, + skip_indexer_scores: bool = False, + ) -> tuple[Any, ...]: aux_streams = self.aux_stream_list if aux_streams is not None: assert len(aux_streams) >= 3 @@ -390,7 +394,8 @@ def indexer_compressor_kv_score() -> torch.Tensor: out_dtype=torch.float32, ) - aux_fns[1] = indexer_weights_proj + if not skip_indexer_scores: + aux_fns[1] = indexer_weights_proj aux_fns[2] = indexer_compressor_kv_score def fused_wqa_wkv() -> torch.Tensor: @@ -410,6 +415,43 @@ def fused_wqa_wkv() -> torch.Tensor: return qr_kv, kv_score, indexer_kv_score, indexer_weights + def _can_skip_c4a_indexer( + self, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> bool: + if ( + not current_platform.is_rocm() + or self.compress_ratio != 4 + or self.indexer is None + or not isinstance(attn_metadata, dict) + ): + return False + + indexer = self.indexer + if not getattr(indexer, "use_fp4_kv", False): + return False + + metadata = attn_metadata.get(indexer.k_cache.prefix) + if metadata is None: + return False + + topk_tokens = indexer.topk_tokens + if metadata.num_prefills > 0: + prefill = metadata.prefill + if prefill is None: + return False + if any(chunk.max_valid_len > topk_tokens for chunk in prefill.chunks): + return False + + if metadata.num_decodes > 0: + decode = metadata.decode + if decode is None or decode.max_seq_len > topk_tokens: + return False + + return True + def attention_impl( self, hidden_states: torch.Tensor, @@ -418,9 
+460,14 @@ def attention_impl( ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + skip_indexer_scores = self._can_skip_c4a_indexer(attn_metadata) + self.mla_attn.use_trivial_c4a_topk = skip_indexer_scores qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) + self.attn_gemm_parallel_execute( + hidden_states, + skip_indexer_scores=skip_indexer_scores, + ) ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) @@ -451,20 +498,42 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: compressor(kv_score, positions, self.rotary_emb) return q - q, _ = maybe_execute_in_parallel( - wq_b_kv_insert_and_compress, - lambda: indexer( - hidden_states, - qr, - indexer_kv_score, - indexer_weights, - positions, - self.indexer_rotary_emb, - ), - self.ln_events[0], - self.ln_events[1], - aux_stream, - ) + if skip_indexer_scores: + assert indexer_kv_score is not None + + def compress_indexer_kv_cache() -> torch.Tensor: + indexer.compressor( + indexer_kv_score, + positions, + self.indexer_rotary_emb, + ) + assert indexer.topk_indices_buffer is not None + return indexer.topk_indices_buffer + + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert_and_compress, + compress_indexer_kv_cache, + self.ln_events[0], + self.ln_events[1], + aux_stream, + ) + else: + assert indexer_kv_score is not None + assert indexer_weights is not None + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert_and_compress, + lambda: indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ), + self.ln_events[0], + self.ln_events[1], + aux_stream, + ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. 
aux_stream = ( @@ -656,6 +725,7 @@ def __init__( self.rope_head_dim = qk_rope_head_dim self.indexer = indexer self.topk_indices_buffer = topk_indices_buffer + self.use_trivial_c4a_topk = False self.prefix = prefix # Alias for compatibility with compressor diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index e58a2cde7bb2..0b373742707f 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -177,6 +177,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: token_end: int num_reqs: int skip_kv_gather: bool = False + max_valid_len: int = 0 @dataclass @@ -195,6 +196,7 @@ class DeepSeekV32IndexerDecodeMetadata: decode_lens: torch.Tensor requires_padding: bool schedule_metadata: torch.Tensor + max_seq_len: int = 0 @dataclass @@ -571,6 +573,13 @@ def build( seq_lens = common_attn_metadata.seq_lens[:num_decodes] block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] + seq_lens_cpu_upper_bound = common_attn_metadata.seq_lens_cpu_upper_bound + if seq_lens_cpu_upper_bound is not None: + max_decode_seq_len = int( + seq_lens_cpu_upper_bound[:num_decodes].max().item() + ) + else: + max_decode_seq_len = common_attn_metadata.max_seq_len max_decode_len = int(decode_lens_cpu.max().item()) next_n = 1 + self.num_speculative_tokens @@ -608,6 +617,7 @@ def build( ) self.expanded_seq_lens_buffer[num_decodes:num_decode_tokens] = 0 seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens] + max_decode_seq_len //= self.compress_ratio # Non-MTP: deep_gemm paged MQA logits requires 2D context_lens # (csrc/apis/attention.hpp). 
Unsqueeze to (B, 1) so downstream @@ -629,6 +639,7 @@ def build( decode_lens=decode_lens, requires_padding=requires_padding, schedule_metadata=self.scheduler_metadata_buffer, + max_seq_len=max_decode_seq_len, ) attn_metadata = DeepseekV32IndexerMetadata( @@ -722,6 +733,7 @@ def build_prefill_chunk_metadata( token_end=token_end, num_reqs=num_reqs, skip_kv_gather=skip_kv_gather, + max_valid_len=int(compressed_seq_lens_cpu[start_idx:end_idx].max().item()), ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py index e7a82a1ee288..043f92104507 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py @@ -155,6 +155,109 @@ def compute_global_topk_ragged_indices_and_indptr( return global_topk_ragged, topk_indptr, topk_lens +@triton.jit +def _compute_trivial_c4a_decode_lens_kernel( + topk_lens_ptr, + query_start_loc_ptr, + seq_lens_ptr, + token_to_req_indices_ptr, + is_valid_token_ptr, + topk: tl.constexpr, + compress_ratio: tl.constexpr, +): + token_idx = tl.program_id(0) + is_valid = tl.load(is_valid_token_ptr + token_idx) + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + query_start = tl.load(query_start_loc_ptr + req_idx) + query_end = tl.load(query_start_loc_ptr + req_idx + 1) + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + req_idx) + prefix_len = seq_len - query_len + pos = prefix_len + token_idx - query_start + compressed_len = (pos + 1) // compress_ratio + compressed_len = tl.minimum(compressed_len, topk) + tl.store(topk_lens_ptr + token_idx, tl.where(is_valid, compressed_len, 0)) + + +@triton.jit +def _pack_trivial_c4a_global_topk_ragged_kernel( + global_topk_ragged_ptr, + topk_indptr_ptr, + topk_lens_ptr, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + block_size, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + block_idx = 
tl.program_id(1) + offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + out_start = tl.load(topk_indptr_ptr + token_idx) + out_len = tl.load(topk_lens_ptr + token_idx) + if block_idx * BLOCK_SIZE >= out_len: + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + valid = offset < out_len + block_indices = offset // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=valid, + other=0, + ) + block_offsets = offset % block_size + slot_ids = block_numbers * block_size + block_offsets + tl.store(global_topk_ragged_ptr + out_start + offset, slot_ids, mask=valid) + + +def compute_trivial_c4a_global_topk_ragged_indices_and_indptr( + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + is_valid_token: torch.Tensor, + compress_ratio: int, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_tokens = token_to_req_indices.shape[0] + topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=seq_lens.device) + _compute_trivial_c4a_decode_lens_kernel[(num_tokens,)]( + topk_lens, + query_start_loc, + seq_lens, + token_to_req_indices, + is_valid_token, + topk=topk, + compress_ratio=compress_ratio, + ) + + topk_indptr = _build_indptr_from_lengths(topk_lens) + global_topk_ragged = torch.empty( + num_tokens * topk, + dtype=torch.int32, + device=seq_lens.device, + ) + if global_topk_ragged.numel() > 0: + block = 256 + _pack_trivial_c4a_global_topk_ragged_kernel[ + (num_tokens, triton.cdiv(topk, block)) + ]( + global_topk_ragged, + topk_indptr, + topk_lens, + token_to_req_indices, + block_table, + block_table.stride(0), + block_size, + BLOCK_SIZE=block, + ) + return global_topk_ragged, topk_indptr, topk_lens + + @triton.jit def _compute_combined_lens_kernel( combined_lens_ptr, @@ -245,6 +348,60 @@ def _combine_topk_swa_indices_ragged_kernel( ) +@triton.jit +def 
_combine_trivial_topk_swa_indices_ragged_kernel( + combined_ragged_ptr, + combined_indptr_ptr, + query_start_loc_ptr, + seq_lens_ptr, + gather_lens_ptr, + M, + N, + TOP_K: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + block_idx = tl.program_id(2) + num_workers = tl.num_programs(1) + + base = tl.load(query_start_loc_ptr) + query_start = tl.load(query_start_loc_ptr + batch_idx) - base + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - base + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + batch_idx) + gather_len = tl.load(gather_lens_ptr + batch_idx) + start_pos = seq_len - query_len + gather_start = seq_len - gather_len + + for token_idx in range(query_start + worker_id, query_end, num_workers): + token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + combined_len = topk_len + swa_len + + offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if block_idx * BLOCK_SIZE < combined_len: + out_start = tl.load(combined_indptr_ptr + token_idx) + topk_mask = offset < topk_len + tl.store( + combined_ragged_ptr + out_start + offset, + offset + M * batch_idx, + mask=topk_mask, + ) + + swa_offset = offset - topk_len + swa_mask = (offset >= topk_len) & (swa_offset < swa_len) + tl.store( + combined_ragged_ptr + out_start + offset, + M * batch_idx + N + swa_offset + pos - swa_len + 1 - gather_start, + mask=swa_mask, + ) + + def combine_topk_swa_indices_ragged( topk_indices: torch.Tensor, query_start_loc: torch.Tensor, @@ -302,6 +459,56 @@ def combine_topk_swa_indices_ragged( return combined_ragged, combined_indptr, combined_lens +def combine_trivial_topk_swa_indices_ragged( + num_tokens: int, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + 
window_size: int, + compress_ratio: int, + topk: int, + M: int, + N: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_reqs = seq_lens.shape[0] + combined_lens = torch.empty(num_tokens, dtype=torch.int32, device=seq_lens.device) + + num_workers = 128 + _compute_combined_lens_kernel[(num_reqs, num_workers)]( + combined_lens, + query_start_loc, + seq_lens, + TOP_K=topk, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=window_size, + ) + + combined_indptr = _build_indptr_from_lengths(combined_lens) + combined_ragged = torch.empty( + num_tokens * (topk + window_size), + dtype=torch.int32, + device=seq_lens.device, + ) + if combined_ragged.numel() > 0: + block = 256 + _combine_trivial_topk_swa_indices_ragged_kernel[ + (num_reqs, num_workers, triton.cdiv(topk + window_size, block)) + ]( + combined_ragged, + combined_indptr, + query_start_loc, + seq_lens, + gather_lens, + M, + N, + TOP_K=topk, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=window_size, + BLOCK_SIZE=block, + ) + return combined_ragged, combined_indptr, combined_lens + + def _copy_ragged_to_graph_buffers( ragged_indices: torch.Tensor, ragged_indptr: torch.Tensor, @@ -573,18 +780,37 @@ def _forward_decode( block_size = attn_metadata.block_size // layer.compress_ratio is_valid = swa_metadata.is_valid_token[:num_decode_tokens] if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None - ( - topk_ragged_indices, - topk_ragged_indptr, - topk_lens, - ) = compute_global_topk_ragged_indices_and_indptr( - layer.topk_indices_buffer[:num_decode_tokens], - swa_metadata.token_to_req_indices, - attn_metadata.block_table[:num_decodes], - block_size, - is_valid, - ) + if getattr(layer, "use_trivial_c4a_topk", False): + assert swa_metadata.query_start_loc is not None + assert swa_metadata.seq_lens is not None + assert swa_metadata.token_to_req_indices is not None + ( + topk_ragged_indices, + topk_ragged_indptr, + topk_lens, + ) = compute_trivial_c4a_global_topk_ragged_indices_and_indptr( + 
swa_metadata.query_start_loc, + swa_metadata.seq_lens, + swa_metadata.token_to_req_indices[:num_decode_tokens], + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + layer.compress_ratio, + layer.indexer.topk_tokens, + ) + else: + assert layer.topk_indices_buffer is not None + ( + topk_ragged_indices, + topk_ragged_indptr, + topk_lens, + ) = compute_global_topk_ragged_indices_and_indptr( + layer.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + ) else: topk_indices = attn_metadata.c128a_global_decode_topk_indices topk_lens = attn_metadata.c128a_decode_topk_lens @@ -642,16 +868,21 @@ def _forward_prefill( assert query_start_loc is not None prefill_token_base = query_start_loc_cpu[num_decodes] + topk_indices = None if not swa_only: if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None - topk_indices = layer.topk_indices_buffer[num_decode_tokens:] - topk_indices = topk_indices[:num_prefill_tokens] + if getattr(layer, "use_trivial_c4a_topk", False): + top_k = layer.indexer.topk_tokens + else: + assert layer.topk_indices_buffer is not None + topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + topk_indices = topk_indices[:num_prefill_tokens] + top_k = topk_indices.shape[-1] else: assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices - assert topk_indices is not None - top_k = topk_indices.shape[-1] + assert topk_indices is not None + top_k = topk_indices.shape[-1] N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio else: assert layer.topk_indices_buffer is not None @@ -660,17 +891,16 @@ def _forward_prefill( N = 0 M = N + layer.window_size + layer.max_num_batched_tokens - num_chunks = (num_prefills + cls._PREFILL_CHUNK_SIZE - 1) // ( - cls._PREFILL_CHUNK_SIZE - ) + prefill_chunk_size = cls._PREFILL_CHUNK_SIZE + num_chunks = (num_prefills + prefill_chunk_size - 1) 
// prefill_chunk_size workspace_manager = current_workspace_manager() kv = workspace_manager.get_simultaneous( - ((cls._PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((prefill_chunk_size, M, q.shape[-1]), torch.bfloat16), )[0] for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * cls._PREFILL_CHUNK_SIZE - chunk_end = min(chunk_start + cls._PREFILL_CHUNK_SIZE, num_prefills) + chunk_start = chunk_idx * prefill_chunk_size + chunk_end = min(chunk_start + prefill_chunk_size, num_prefills) chunk_size = chunk_end - chunk_start if not swa_only: assert attn_metadata is not None @@ -704,21 +934,45 @@ def _forward_prefill( query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base ) - combined_ragged_indices, combined_ragged_indptr, combined_lens = ( - combine_topk_swa_indices_ragged( - topk_indices[query_start:query_end], - query_start_loc[ - num_decodes + chunk_start : num_decodes + chunk_end + 1 - ], - seq_lens[chunk_start:chunk_end], - gather_lens[chunk_start:chunk_end], - layer.window_size, - layer.compress_ratio, - top_k, - M, - N, + if ( + layer.compress_ratio == 4 + and getattr(layer, "use_trivial_c4a_topk", False) + ): + combined_ragged_indices, combined_ragged_indptr, combined_lens = ( + combine_trivial_topk_swa_indices_ragged( + query_end - query_start, + query_start_loc[ + num_decodes + chunk_start : num_decodes + chunk_end + 1 + ], + seq_lens[chunk_start:chunk_end], + gather_lens[chunk_start:chunk_end], + layer.window_size, + layer.compress_ratio, + top_k, + M, + N, + ) + ) + else: + assert topk_indices is not None + combined_ragged_indices, combined_ragged_indptr, combined_lens = ( + combine_topk_swa_indices_ragged( + topk_indices[query_start:query_end], + query_start_loc[ + num_decodes + + chunk_start : num_decodes + + chunk_end + + 1 + ], + seq_lens[chunk_start:chunk_end], + gather_lens[chunk_start:chunk_end], + layer.window_size, + layer.compress_ratio, + top_k, + M, + N, + ) ) - ) rocm_sparse_attn_prefill( 
q=q[query_start:query_end], kv=kv.view(-1, 1, q.shape[-1]), diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 892095c0cd93..516a56f8dfef 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -656,20 +656,133 @@ def _has_vllm_topk_ops() -> bool: ) +@triton.jit +def _fill_prefill_topk_indices_kernel( + row_starts_ptr, + row_ends_ptr, + out_ptr, + out_stride, + topk_tokens: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * BLOCK_TOPK + tl.arange(0, BLOCK_TOPK) + row_len = tl.load(row_ends_ptr + row) - tl.load(row_starts_ptr + row) + row_len = tl.maximum(row_len, 0) + values = tl.where(offsets < row_len, offsets, -1) + tl.store( + out_ptr + row * out_stride + offsets, + values, + mask=offsets < topk_tokens, + ) + + +def _fill_prefill_topk_indices( + row_starts: torch.Tensor, + row_ends: torch.Tensor, + out: torch.Tensor, + topk_tokens: int, +) -> torch.Tensor: + block_topk = 256 + grid = (out.shape[0], triton.cdiv(topk_tokens, block_topk)) + _fill_prefill_topk_indices_kernel[grid]( + row_starts, + row_ends, + out, + out.stride(0), + topk_tokens, + BLOCK_TOPK=block_topk, + ) + return out + + +@triton.jit +def _fill_decode_topk_indices_kernel( + seq_lens_ptr, + out_ptr, + seq_lens_stride_b, + seq_lens_stride_n, + out_stride, + topk_tokens: tl.constexpr, + next_n: tl.constexpr, + SEQ_LENS_2D: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + batch = row // next_n + next_idx = row - batch * next_n + if SEQ_LENS_2D: + seq_len = tl.load( + seq_lens_ptr + batch * seq_lens_stride_b + next_idx * seq_lens_stride_n + ) + else: + batch_seq_len = tl.load(seq_lens_ptr + batch * seq_lens_stride_b) + seq_len = batch_seq_len - next_n + next_idx + 1 + seq_len = tl.maximum(seq_len, 0) + + offsets = block * BLOCK_TOPK + tl.arange(0, BLOCK_TOPK) + 
values = tl.where(offsets < seq_len, offsets, -1) + tl.store( + out_ptr + row * out_stride + offsets, + values, + mask=offsets < topk_tokens, + ) + + +def _fill_decode_topk_indices( + seq_lens: torch.Tensor, + topk_tokens: int, + next_n: int, + num_rows: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + indices = out + if indices is None: + indices = torch.empty( + (num_rows, topk_tokens), + dtype=torch.int32, + device=seq_lens.device, + ) + + block_topk = 256 + grid = (num_rows, triton.cdiv(topk_tokens, block_topk)) + _fill_decode_topk_indices_kernel[grid]( + seq_lens, + indices, + seq_lens.stride(0), + seq_lens.stride(1) if seq_lens.dim() == 2 else 0, + indices.stride(0), + topk_tokens, + next_n, + SEQ_LENS_2D=seq_lens.dim() == 2, + BLOCK_TOPK=block_topk, + ) + return indices + + def _topk_indices_vllm_prefill( logits: torch.Tensor, topk_tokens: int, row_starts: torch.Tensor, row_ends: torch.Tensor, + out: torch.Tensor | None = None, ) -> torch.Tensor: if not _has_vllm_topk_ops(): - return _topk_indices_torch(logits, topk_tokens, row_starts) + indices = _topk_indices_torch(logits, topk_tokens, row_starts) + if out is not None: + out.copy_(indices) + return out + return indices - indices = torch.empty( - (logits.shape[0], topk_tokens), - dtype=torch.int32, - device=logits.device, - ) + indices = out + if indices is None: + indices = torch.empty( + (logits.shape[0], topk_tokens), + dtype=torch.int32, + device=logits.device, + ) torch.ops._C.top_k_per_row_prefill( logits, row_starts, @@ -688,15 +801,22 @@ def _topk_indices_vllm_decode( topk_tokens: int, next_n: int, seq_lens: torch.Tensor, + out: torch.Tensor | None = None, ) -> torch.Tensor: if not _has_vllm_topk_ops(): - return _topk_indices_torch(logits, topk_tokens) + indices = _topk_indices_torch(logits, topk_tokens) + if out is not None: + out.copy_(indices) + return out + return indices - indices = torch.empty( - (logits.shape[0], topk_tokens), - dtype=torch.int32, - device=logits.device, - ) + 
indices = out + if indices is None: + indices = torch.empty( + (logits.shape[0], topk_tokens), + dtype=torch.int32, + device=logits.device, + ) torch.ops._C.top_k_per_row_decode( logits, next_n, @@ -845,26 +965,32 @@ def rocm_aiter_sparse_attn_indexer_native( scale_fmt, ) - topk_indices_buffer[: hidden_states.shape[0]] = -1 + if hidden_states.shape[0] > num_tokens: + topk_indices_buffer[num_tokens : hidden_states.shape[0], :topk_tokens] = -1 if has_prefill: prefill_metadata = layer_attn_metadata.prefill assert prefill_metadata is not None + prefill_chunks = prefill_metadata.chunks + needs_prefill_logits = not ( + use_fp4_cache + and all(chunk.max_valid_len <= topk_tokens for chunk in prefill_chunks) + ) max_total_seq_lens = max( - (chunk.total_seq_lens for chunk in prefill_metadata.chunks), + (chunk.total_seq_lens for chunk in prefill_chunks), default=0, ) - if use_fp4_cache: - k_quant_full = torch.zeros( + if use_fp4_cache and needs_prefill_logits: + k_quant_full = torch.empty( [max_total_seq_lens, head_dim // 2], device=device, dtype=torch.uint8, ) - k_scale_full = torch.zeros( + k_scale_full = torch.empty( [max_total_seq_lens, head_dim // MXFP4_BLOCK_SIZE], device=device, dtype=torch.uint8, ) - else: + elif not use_fp4_cache: k_quant_full = torch.zeros( [max_total_seq_lens, head_dim], device=device, @@ -875,35 +1001,53 @@ def rocm_aiter_sparse_attn_indexer_native( device=device, dtype=torch.uint8, ) - for chunk in prefill_metadata.chunks: + else: + k_quant_full = None + k_scale_full = None + for chunk in prefill_chunks: + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + if use_fp4_cache and chunk.max_valid_len <= topk_tokens: + _fill_prefill_topk_indices( + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + topk_tokens, + ) + continue + + assert k_quant_full is not None + assert k_scale_full is not None k_quant = k_quant_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens] - if 
use_fp4_cache: - cp_gather_indexer_mxfp4_cache_triton( - kv_cache, - k_quant, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - token_to_seq=chunk.token_to_seq, - ) - elif _ON_GFX942: - ops.cp_gather_indexer_k_quant_cache( - kv_cache, - k_quant, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - ) - else: - cp_gather_indexer_k_quant_cache_triton( - kv_cache, - k_quant, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - token_to_seq=chunk.token_to_seq, - ) + if not chunk.skip_kv_gather: + if use_fp4_cache: + cp_gather_indexer_mxfp4_cache_triton( + kv_cache, + k_quant, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + token_to_seq=chunk.token_to_seq, + ) + elif _ON_GFX942: + ops.cp_gather_indexer_k_quant_cache( + kv_cache, + k_quant, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + ) + else: + cp_gather_indexer_k_quant_cache_triton( + kv_cache, + k_quant, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + token_to_seq=chunk.token_to_seq, + ) q_slice = q_values[chunk.token_start : chunk.token_end] weight_slice = weights[chunk.token_start : chunk.token_end] @@ -930,17 +1074,13 @@ def rocm_aiter_sparse_attn_indexer_native( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] if use_fp4_cache: - topk_indices.copy_( - _topk_indices_vllm_prefill( - logits, - topk_tokens, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - ) + _topk_indices_vllm_prefill( + logits, + topk_tokens, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + out=topk_indices, ) else: topk_indices.copy_( @@ -1012,25 +1152,56 @@ def rocm_aiter_sparse_attn_indexer_native( assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] if use_fp4_cache: - use_vllm_topk = _has_vllm_topk_ops() - active_paged_width = ( - decode_metadata.block_table.shape[1] * kv_cache_decode.shape[1] - ) - logits_width = min( - 
max_model_len, - max(topk_tokens, active_paged_width), - ) - logits = fp8_fp4_paged_mqa_logits( - (padded_q_decode_tokens.view(torch.int8), padded_q_scale), - kv_cache_decode, - weights[:num_padded_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=logits_width, - clean_logits=not use_vllm_topk, - ) + if decode_metadata.max_seq_len <= topk_tokens: + topk_out = ( + topk_indices + if num_padded_tokens == topk_indices.shape[0] + else None + ) + decoded_topk = _fill_decode_topk_indices( + decode_metadata.seq_lens, + topk_tokens, + next_n, + num_padded_tokens, + out=topk_out, + ) + if topk_out is None: + topk_indices.copy_(decoded_topk[:num_decode_tokens]) + else: + use_vllm_topk = _has_vllm_topk_ops() + active_paged_width = ( + decode_metadata.block_table.shape[1] * kv_cache_decode.shape[1] + ) + logits_width = min( + max_model_len, + max(topk_tokens, active_paged_width), + ) + logits = fp8_fp4_paged_mqa_logits( + (padded_q_decode_tokens.view(torch.int8), padded_q_scale), + kv_cache_decode, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=not use_vllm_topk, + ) + topk_out = ( + topk_indices + if logits.shape[0] == topk_indices.shape[0] + else None + ) + decoded_topk = _topk_indices_vllm_decode( + logits, + topk_tokens, + next_n, + decode_metadata.seq_lens, + out=topk_out, + ) + if topk_out is None: + topk_indices.copy_(decoded_topk[:num_decode_tokens]) else: logits = rocm_fp8_paged_mqa_logits( padded_q_decode_tokens, @@ -1041,18 +1212,6 @@ def rocm_aiter_sparse_attn_indexer_native( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - - topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - if use_fp4_cache: - topk_indices.copy_( - _topk_indices_vllm_decode( - logits, - topk_tokens, - next_n, - decode_metadata.seq_lens, - )[:num_decode_tokens] - ) - 
else: topk_indices.copy_( _topk_indices_torch(logits, topk_tokens)[:num_decode_tokens] ) @@ -1666,6 +1825,7 @@ def _rocm_sparse_attn_prefill_ragged_triton( attn_sink: torch.Tensor | None, nope_head_dim: int, rope_head_dim: int, + output: torch.Tensor | None = None, ) -> torch.Tensor: assert q.ndim == 3, f"expected q=[sq,h,d], got {q.shape}" assert kv.ndim == 2, f"expected kv=[skv,d], got {kv.shape}" @@ -1695,7 +1855,17 @@ def _rocm_sparse_attn_prefill_ragged_triton( block_h = 16 block_d = triton.next_power_of_2(head_dim) block_k = 16 if head_dim >= 256 else 32 - out = torch.empty_like(q, dtype=torch.bfloat16) + if output is None: + out = torch.empty_like(q, dtype=torch.bfloat16) + else: + assert output.shape == q.shape, ( + f"expected output shape {q.shape}, got {output.shape}" + ) + assert output.dtype == torch.bfloat16, ( + "direct ROCm sparse attention output currently expects bfloat16, " + f"got {output.dtype}" + ) + out = output _sparse_attn_prefill_ragged_kernel[(num_queries, triton.cdiv(num_heads, block_h))]( q, kv, @@ -1733,6 +1903,7 @@ def _rocm_sparse_attn_prefill_triton( nope_head_dim: int, rope_head_dim: int, topk_length: torch.Tensor | None = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: ragged_indices, ragged_indptr = build_ragged_indices_from_dense( indices, @@ -1750,6 +1921,7 @@ def _rocm_sparse_attn_prefill_triton( attn_sink=attn_sink, nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, + output=output, ) @@ -1765,6 +1937,7 @@ def _rocm_sparse_attn_decode_ragged_triton( extra_cache: torch.Tensor | None = None, extra_indices: torch.Tensor | None = None, extra_indptr: torch.Tensor | None = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: assert q.ndim == 3, f"expected q=[b,h,d], got {q.shape}" assert main_cache.ndim == 3, ( @@ -1825,9 +1998,22 @@ def _rocm_sparse_attn_decode_ragged_triton( extra_indices = torch.empty(0, device=q.device, dtype=torch.int32) extra_indptr = torch.zeros(num_queries + 1, 
device=q.device, dtype=torch.int32) + assert extra_indices is not None + assert extra_indptr is not None + block_h = 16 - block_k = 16 if head_dim >= 256 else 32 - out = torch.empty_like(q, dtype=torch.bfloat16) + block_k = 32 + if output is None: + out = torch.empty_like(q, dtype=torch.bfloat16) + else: + assert output.shape == q.shape, ( + f"expected output shape {q.shape}, got {output.shape}" + ) + assert output.dtype == torch.bfloat16, ( + "direct ROCm sparse attention output currently expects bfloat16, " + f"got {output.dtype}" + ) + out = output _sparse_attn_decode_ragged_kernel[(num_queries, triton.cdiv(num_heads, block_h))]( q, main_cache, @@ -1858,7 +2044,7 @@ def _rocm_sparse_attn_decode_ragged_triton( IS_FNUZ=current_platform.is_fp8_fnuz(), BLOCK_H=block_h, BLOCK_K=block_k, - num_warps=8, + num_warps=4, ) return out @@ -1879,6 +2065,7 @@ def _rocm_sparse_attn_decode_triton( main_ragged_indptr: torch.Tensor | None = None, extra_ragged_indices: torch.Tensor | None = None, extra_ragged_indptr: torch.Tensor | None = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: if main_ragged_indices is None or main_ragged_indptr is None: main_ragged_indices, main_ragged_indptr = build_ragged_indices_from_dense( @@ -1914,6 +2101,7 @@ def _rocm_sparse_attn_decode_triton( extra_cache=extra_cache, extra_indices=extra_ragged_indices, extra_indptr=extra_ragged_indptr, + output=output, ) @@ -1941,6 +2129,7 @@ def rocm_sparse_attn_prefill( "rocm_sparse_attn_prefill", ) if ragged_indices is not None and ragged_indptr is not None: + direct_output = output if output.dtype == torch.bfloat16 else None output_chunk = _rocm_sparse_attn_prefill_ragged_triton( q=q, kv=kv.squeeze(1), @@ -1950,9 +2139,11 @@ def rocm_sparse_attn_prefill( attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, + output=direct_output, ) else: indices_2d = indices.reshape(indices.shape[0], -1) + direct_output = output if 
output.dtype == torch.bfloat16 else None output_chunk = _rocm_sparse_attn_prefill_triton( q=q, kv=kv.squeeze(1), @@ -1962,8 +2153,10 @@ def rocm_sparse_attn_prefill( nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, topk_length=topk_length, + output=direct_output, ) - output.copy_(output_chunk.to(output.dtype)) + if output_chunk is not output: + output.copy_(output_chunk.to(output.dtype)) def rocm_sparse_attn_decode( @@ -2014,6 +2207,7 @@ def rocm_sparse_attn_decode( if topk_indices is not None: extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + direct_output = output if output.dtype == torch.bfloat16 else None attn_out = _rocm_sparse_attn_decode_triton( q=q, main_cache=swa_k_cache, @@ -2030,5 +2224,7 @@ def rocm_sparse_attn_decode( main_ragged_indptr=swa_ragged_indptr, extra_ragged_indices=topk_ragged_indices, extra_ragged_indptr=topk_ragged_indptr, + output=direct_output, ) - output.copy_(attn_out.to(output.dtype)) + if attn_out is not output: + output.copy_(attn_out.to(output.dtype)) diff --git a/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py index b5e60f488501..290f1b020af3 100644 --- a/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py +++ b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py @@ -190,10 +190,11 @@ def _fp4_paged_mqa_logits_kernel( NUM_HEADS: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_KV: tl.constexpr, + KV_BLOCKS_PER_PROG: tl.constexpr, IS_CONTEXT_LENS_2D: tl.constexpr, ): row = tl.program_id(0) - logical_block = tl.program_id(1) + logical_block_group = tl.program_id(1) batch = row // NEXT_N next_idx = row - batch * NEXT_N @@ -207,27 +208,34 @@ def _fp4_paged_mqa_logits_kernel( valid_limit = context_len - NEXT_N + next_idx + 1 valid_limit = tl.minimum(valid_limit, max_context_len) - tile_start = logical_block * KV_BLOCK_SIZE + first_logical_block = logical_block_group * KV_BLOCKS_PER_PROG + tile_start = first_logical_block * KV_BLOCK_SIZE if 
tile_start >= valid_limit: return - physical_block = tl.load( - block_table_ptr + batch * block_table_stride_b + logical_block - ) value_dim: tl.constexpr = HEAD_DIM // 2 scale_dim: tl.constexpr = HEAD_DIM // 32 - block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B - scale_base = (block_base + KV_BLOCK_SIZE * value_dim).to( - tl.pointer_type(tl.int32) - ) h = tl.arange(0, NUM_HEADS) d_byte = tl.arange(0, value_dim) scale_idx = tl.arange(0, scale_dim) k_offsets = tl.arange(0, BLOCK_KV) - token_pos = k_offsets global_k = tile_start + k_offsets + logical_blocks = first_logical_block + k_offsets // KV_BLOCK_SIZE + token_pos = k_offsets % KV_BLOCK_SIZE valid_k = (k_offsets < KV_BLOCK_SIZE) & (global_k < valid_limit) + if KV_BLOCKS_PER_PROG != 1: + valid_k = global_k < valid_limit + + physical_block = tl.load( + block_table_ptr + batch * block_table_stride_b + logical_blocks, + mask=valid_k, + other=0, + ) + block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B + scale_base = (block_base + KV_BLOCK_SIZE * value_dim).to( + tl.pointer_type(tl.int32) + ) q_packed = tl.load( q_ptr @@ -250,7 +258,7 @@ def _fp4_paged_mqa_logits_kernel( k_packed = tl.load( block_base + token_pos[None, :] * value_dim + d_byte[:, None], - mask=(k_offsets[None, :] < KV_BLOCK_SIZE), + mask=valid_k[None, :], other=0, ).to(tl.uint8) k_scale_word = tl.load(scale_base + token_pos, mask=valid_k, other=0) @@ -499,7 +507,7 @@ def rocm_fp4_mqa_logits( if max_valid_len == 0 or seq_len_kv == 0: return logits - block_kv = 64 + block_kv = 128 grid = (seq_len, triton.cdiv(max_valid_len, block_kv)) _fp4_mqa_logits_kernel[grid]( q_values, @@ -622,12 +630,16 @@ def rocm_fp8_fp4_paged_mqa_logits( if num_logical_blocks == 0: return logits - grid = (batch_size * next_n, num_logical_blocks) - block_kv = _block_kv_for_dot(block_size) is_context_lens_2d = context_lens.dim() == 2 if is_fp4: assert q_scales is not None + kv_blocks_per_prog = 2 + grid = ( + batch_size * next_n, 
+ triton.cdiv(num_logical_blocks, kv_blocks_per_prog), + ) + block_kv = _block_kv_for_dot(block_size * kv_blocks_per_prog) _fp4_paged_mqa_logits_kernel[grid]( q_values, q_scales, @@ -656,10 +668,13 @@ def rocm_fp8_fp4_paged_mqa_logits( NUM_HEADS=num_heads, HEAD_DIM=head_dim, BLOCK_KV=block_kv, + KV_BLOCKS_PER_PROG=kv_blocks_per_prog, IS_CONTEXT_LENS_2D=is_context_lens_2d, - num_warps=4, + num_warps=8, ) else: + grid = (batch_size * next_n, num_logical_blocks) + block_kv = _block_kv_for_dot(block_size) _fp8_paged_mqa_logits_kernel[grid]( q_values, kv_cache, From f557e351e375ccab9fe89e167793aaedf4d6494a Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sun, 17 May 2026 20:08:10 -0500 Subject: [PATCH 6/9] cleanup Signed-off-by: tjtanaa --- fmaacctests/README.md | 36 ------ fmaacctests/accuracy.py | 211 ------------------------------- fmaacctests/benchmark.py | 239 ------------------------------------ fmaacctests/fma_variants.py | 189 ---------------------------- 4 files changed, 675 deletions(-) delete mode 100644 fmaacctests/README.md delete mode 100644 fmaacctests/accuracy.py delete mode 100644 fmaacctests/benchmark.py delete mode 100644 fmaacctests/fma_variants.py diff --git a/fmaacctests/README.md b/fmaacctests/README.md deleted file mode 100644 index ea1525381a50..000000000000 --- a/fmaacctests/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# FMA Accuracy and Benchmark Checks - -These scripts compare the old fused indexer-Q RoPE arithmetic: - -```python -r_even = x_even * cos - x_odd * sin -r_odd = x_odd * cos + x_even * sin -``` - -against the explicit `tl.fma` version used in the production kernel. - -Run from the vLLM checkout root: - -```bash -uv run --no-project python fmaacctests/accuracy.py -uv run --no-project python fmaacctests/benchmark.py -``` - -The accuracy script compares both variants against the unfused vLLM reference: -`ops.rotary_embedding` followed by `per_token_group_quant_fp8(..., -use_ue8m0=True)`. 
It returns non-zero only if the `fma` variant mismatches the -reference. - -The benchmark script uses preallocated outputs and CUDA events to time only the -Triton kernel launch/execution path. Example with JSON output: - -```bash -uv run --no-project python fmaacctests/accuracy.py --json fma_accuracy.json -uv run --no-project python fmaacctests/benchmark.py --json fma_benchmark.json -``` - -Observed on the local ROCm MI355X environment, the old `muladd` variant had two -FP8 Q mismatches against the vLLM reference for float32 RoPE caches at token -counts 257 and 1023. The `tl.fma` variant had zero mismatches for all tested -cases. The microbenchmark did not show a speed regression; the largest tested -cases were about 4% faster with `tl.fma`. diff --git a/fmaacctests/accuracy.py b/fmaacctests/accuracy.py deleted file mode 100644 index 6d07c76c8dc8..000000000000 --- a/fmaacctests/accuracy.py +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Check explicit tl.fma accuracy for fused DeepSeek V4 indexer-Q RoPE.""" - -from __future__ import annotations - -import argparse -import json -import sys -from pathlib import Path - -import torch - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -from fma_variants import HEAD_DIM, MAX_POS, N_HEAD, ROPE_DIM, run_variant -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, -) - - -def _dtype(name: str) -> torch.dtype: - if name == "float32": - return torch.float32 - if name == "bfloat16": - return torch.bfloat16 - raise ValueError(f"unsupported dtype: {name}") - - -def make_inputs( - num_tokens: int, - cache_dtype: torch.dtype, - device: str, - seed: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - torch.manual_seed(seed) - q = torch.randn( - num_tokens, - N_HEAD, - HEAD_DIM, - dtype=torch.bfloat16, - device=device, - ) - 
positions = torch.randint( - 0, - MAX_POS, - (num_tokens,), - dtype=torch.int64, - device=device, - ) - cos_sin_cache = torch.randn( - MAX_POS, - ROPE_DIM, - dtype=cache_dtype, - device=device, - ) - weights = torch.randn( - num_tokens, - N_HEAD, - dtype=torch.bfloat16, - device=device, - ) - return positions, q, cos_sin_cache, weights - - -def reference( - positions: torch.Tensor, - q: torch.Tensor, - cos_sin_cache: torch.Tensor, - weights: torch.Tensor, - softmax_scale: float, - head_scale: float, -) -> tuple[torch.Tensor, torch.Tensor]: - q_rot = q.clone() - ops.rotary_embedding( - positions, - q_rot, - None, - HEAD_DIM, - cos_sin_cache, - False, - HEAD_DIM - ROPE_DIM, - False, - ) - q_fp8, q_scale = per_token_group_quant_fp8( - q_rot.view(-1, HEAD_DIM).contiguous(), - HEAD_DIM, - use_ue8m0=True, - ) - q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) - q_scale = q_scale.view(-1, N_HEAD) - weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale - return q_fp8, weights_out - - -def compare_case( - num_tokens: int, - cache_dtype_name: str, - device: str, - seed: int, -) -> list[dict[str, object]]: - cache_dtype = _dtype(cache_dtype_name) - positions, q, cos_sin_cache, weights = make_inputs( - num_tokens, - cache_dtype, - device, - seed, - ) - softmax_scale = HEAD_DIM**-0.5 - head_scale = N_HEAD**-0.5 - q_ref, weights_ref = reference( - positions, - q, - cos_sin_cache, - weights, - softmax_scale, - head_scale, - ) - - rows: list[dict[str, object]] = [] - for label, use_fma in (("muladd", False), ("fma", True)): - q_actual, weights_actual = run_variant( - positions, - q, - cos_sin_cache, - weights, - softmax_scale, - head_scale, - use_fma=use_fma, - ) - torch.cuda.synchronize() - ref_bits = q_ref.view(torch.int8) - actual_bits = q_actual.view(torch.int8) - q_mismatches = int((ref_bits != actual_bits).sum().item()) - weight_diff = (weights_ref - weights_actual).abs() - rows.append( - { - "num_tokens": num_tokens, - "cache_dtype": cache_dtype_name, - 
"variant": label, - "q_mismatches": q_mismatches, - "q_elements": ref_bits.numel(), - "weights_equal": bool(torch.equal(weights_ref, weights_actual)), - "max_weight_abs_diff": float(weight_diff.max().item()), - } - ) - return rows - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cuda") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument( - "--tokens", - type=int, - nargs="+", - default=[1, 7, 32, 257, 1023], - ) - parser.add_argument( - "--cache-dtypes", - nargs="+", - default=["float32", "bfloat16"], - choices=["float32", "bfloat16"], - ) - parser.add_argument("--json", type=Path) - parser.add_argument( - "--no-strict-fma", - action="store_true", - help="Do not return non-zero if the fma variant mismatches reference.", - ) - return parser.parse_args() - - -def main() -> int: - args = parse_args() - all_rows: list[dict[str, object]] = [] - for num_tokens in args.tokens: - for cache_dtype_name in args.cache_dtypes: - all_rows.extend( - compare_case(num_tokens, cache_dtype_name, args.device, args.seed) - ) - - print( - "tokens cache_dtype variant q_mismatches/q_elements " - "weights_equal max_weight_abs_diff" - ) - for row in all_rows: - print( - f"{row['num_tokens']:>6} {row['cache_dtype']:<8} " - f"{row['variant']:<6} {row['q_mismatches']}/{row['q_elements']} " - f"{row['weights_equal']} {row['max_weight_abs_diff']:.9g}" - ) - - if args.json is not None: - args.json.write_text(json.dumps(all_rows, indent=2) + "\n") - - if not args.no_strict_fma: - for row in all_rows: - if ( - row["variant"] == "fma" - and (row["q_mismatches"] != 0 or not row["weights_equal"]) - ): - return 1 - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/fmaacctests/benchmark.py b/fmaacctests/benchmark.py deleted file mode 100644 index de81846624ef..000000000000 --- a/fmaacctests/benchmark.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 
-"""Benchmark old mul/add RoPE arithmetic against explicit tl.fma.""" - -from __future__ import annotations - -import argparse -import json -import statistics -import sys -from pathlib import Path - -import torch - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -from fma_variants import HEAD_DIM, MAX_POS, N_HEAD, ROPE_DIM, launch_variant - - -def _dtype(name: str) -> torch.dtype: - if name == "float32": - return torch.float32 - if name == "bfloat16": - return torch.bfloat16 - raise ValueError(f"unsupported dtype: {name}") - - -def make_inputs( - num_tokens: int, - cache_dtype: torch.dtype, - device: str, - seed: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - torch.manual_seed(seed) - q = torch.randn( - num_tokens, - N_HEAD, - HEAD_DIM, - dtype=torch.bfloat16, - device=device, - ) - positions = torch.randint( - 0, - MAX_POS, - (num_tokens,), - dtype=torch.int64, - device=device, - ) - cos_sin_cache = torch.randn( - MAX_POS, - ROPE_DIM, - dtype=cache_dtype, - device=device, - ) - weights = torch.randn( - num_tokens, - N_HEAD, - dtype=torch.bfloat16, - device=device, - ) - return positions, q, cos_sin_cache, weights - - -def time_variant( - positions: torch.Tensor, - q: torch.Tensor, - cos_sin_cache: torch.Tensor, - weights: torch.Tensor, - *, - use_fma: bool, - warmup: int, - iters: int, - repeats: int, -) -> dict[str, float]: - softmax_scale = HEAD_DIM**-0.5 - head_scale = N_HEAD**-0.5 - q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fn) - weights_out = torch.empty_like(weights, dtype=torch.float32) - - for _ in range(warmup): - launch_variant( - positions, - q, - cos_sin_cache, - weights, - softmax_scale, - head_scale, - q_fp8, - weights_out, - use_fma=use_fma, - ) - torch.cuda.synchronize() - - samples: list[float] = [] - for _ in range(repeats): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - 
for _ in range(iters): - launch_variant( - positions, - q, - cos_sin_cache, - weights, - softmax_scale, - head_scale, - q_fp8, - weights_out, - use_fma=use_fma, - ) - end.record() - torch.cuda.synchronize() - samples.append(start.elapsed_time(end) / iters) - - return { - "mean_ms": statistics.fmean(samples), - "median_ms": statistics.median(samples), - "min_ms": min(samples), - "max_ms": max(samples), - } - - -def run_case( - num_tokens: int, - cache_dtype_name: str, - device: str, - seed: int, - warmup: int, - iters: int, - repeats: int, -) -> dict[str, object]: - cache_dtype = _dtype(cache_dtype_name) - positions, q, cos_sin_cache, weights = make_inputs( - num_tokens, - cache_dtype, - device, - seed, - ) - - # Compile both specializations before timing either one. - for use_fma in (False, True): - q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fn) - weights_out = torch.empty_like(weights, dtype=torch.float32) - launch_variant( - positions, - q, - cos_sin_cache, - weights, - HEAD_DIM**-0.5, - N_HEAD**-0.5, - q_fp8, - weights_out, - use_fma=use_fma, - ) - torch.cuda.synchronize() - - muladd = time_variant( - positions, - q, - cos_sin_cache, - weights, - use_fma=False, - warmup=warmup, - iters=iters, - repeats=repeats, - ) - fma = time_variant( - positions, - q, - cos_sin_cache, - weights, - use_fma=True, - warmup=warmup, - iters=iters, - repeats=repeats, - ) - delta_pct = (fma["median_ms"] / muladd["median_ms"] - 1.0) * 100.0 - return { - "num_tokens": num_tokens, - "cache_dtype": cache_dtype_name, - "muladd": muladd, - "fma": fma, - "fma_vs_muladd_median_delta_pct": delta_pct, - } - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cuda") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--tokens", type=int, nargs="+", default=[257, 1023, 4096]) - parser.add_argument( - "--cache-dtypes", - nargs="+", - default=["float32", "bfloat16"], - choices=["float32", 
"bfloat16"], - ) - parser.add_argument("--warmup", type=int, default=20) - parser.add_argument("--iters", type=int, default=100) - parser.add_argument("--repeats", type=int, default=5) - parser.add_argument("--json", type=Path) - return parser.parse_args() - - -def main() -> int: - args = parse_args() - rows: list[dict[str, object]] = [] - for num_tokens in args.tokens: - for cache_dtype_name in args.cache_dtypes: - rows.append( - run_case( - num_tokens, - cache_dtype_name, - args.device, - args.seed, - args.warmup, - args.iters, - args.repeats, - ) - ) - - print("tokens cache_dtype muladd_median_ms fma_median_ms delta_pct") - for row in rows: - muladd = row["muladd"] - fma = row["fma"] - print( - f"{row['num_tokens']:>6} {row['cache_dtype']:<8} " - f"{muladd['median_ms']:.6f} {fma['median_ms']:.6f} " - f"{row['fma_vs_muladd_median_delta_pct']:+.3f}" - ) - - if args.json is not None: - args.json.write_text(json.dumps(rows, indent=2) + "\n") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/fmaacctests/fma_variants.py b/fmaacctests/fma_variants.py deleted file mode 100644 index 3b8c4bd2bd90..000000000000 --- a/fmaacctests/fma_variants.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Fused indexer-Q RoPE kernels used by reviewer accuracy/perf scripts.""" - -from __future__ import annotations - -import torch - -from vllm.triton_utils import tl, triton - -HEAD_DIM = 128 -ROPE_DIM = 64 -N_HEAD = 64 -MAX_POS = 4096 - - -@triton.jit -def _get_cos_sin( - cos_sin_cache_ptr, - cos_sin_cache_stride, - pos, - HALF_ROT_DIM: tl.constexpr, -): - block = tl.arange(0, HALF_ROT_DIM) - cos = tl.load(cos_sin_cache_ptr + pos * cos_sin_cache_stride + block) - cos = cos.to(tl.float32) - sin = tl.load(cos_sin_cache_ptr + pos * cos_sin_cache_stride + block + HALF_ROT_DIM) - sin = sin.to(tl.float32) - return cos, sin - - -@triton.jit -def _fused_indexer_q_rope_quant_variant_kernel( - pos_ptr, - index_q_ptr, - index_q_stride0, - 
index_q_stride1, - index_q_cos_sin_ptr, - index_q_cos_sin_stride, - INDEX_Q_HALF_ROT_DIM: tl.constexpr, - index_q_fp8_ptr, - index_q_fp8_stride0, - index_q_fp8_stride1, - INDEX_Q_HEAD_DIM: tl.constexpr, - index_weights_ptr, - index_weights_stride, - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out_ptr, - index_weights_out_stride, - USE_FMA: tl.constexpr, -): - index_q_rot_dim: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM - index_q_nope_dim: tl.constexpr = INDEX_Q_HEAD_DIM - index_q_rot_dim - tl.static_assert(index_q_nope_dim >= 0) - - tok_idx = tl.program_id(0) - head_idx = tl.program_id(1) - - pos = tl.load(pos_ptr + tok_idx) - cos, sin = _get_cos_sin( - index_q_cos_sin_ptr, - index_q_cos_sin_stride, - pos, - INDEX_Q_HALF_ROT_DIM, - ) - half_offset = tl.arange(0, INDEX_Q_HALF_ROT_DIM) - base_ptr = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1 - - rot_base = base_ptr + index_q_nope_dim - x_even = tl.load(rot_base + half_offset * 2).to(tl.float32) - x_odd = tl.load(rot_base + half_offset * 2 + 1).to(tl.float32) - if USE_FMA: - r_even = tl.fma(x_even, cos, -(x_odd * sin)) - r_odd = tl.fma(x_odd, cos, x_even * sin) - else: - r_even = x_even * cos - x_odd * sin - r_odd = x_odd * cos + x_even * sin - - r_even = r_even.to(tl.bfloat16).to(tl.float32) - r_odd = r_odd.to(tl.bfloat16).to(tl.float32) - - amax = tl.maximum(tl.max(tl.abs(r_even)), tl.max(tl.abs(r_odd))) - if index_q_nope_dim > 0: - nope_offset = tl.arange(0, index_q_nope_dim) - x_nope = tl.load(base_ptr + nope_offset).to(tl.float32) - amax = tl.maximum(amax, tl.max(tl.abs(x_nope))) - index_q_scale = tl.div_rn(tl.maximum(amax, 1e-4), 448.0) - index_q_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(index_q_scale))) - - fp8_base_ptr = ( - index_q_fp8_ptr + tok_idx * index_q_fp8_stride0 + head_idx * index_q_fp8_stride1 - ) - if index_q_nope_dim > 0: - tl.store( - fp8_base_ptr + nope_offset, - tl.div_rn(x_nope, index_q_scale).to(tl.float8e4nv), - ) - fp8_rot_base = 
fp8_base_ptr + index_q_nope_dim - tl.store( - fp8_rot_base + half_offset * 2, - tl.div_rn(r_even, index_q_scale).to(tl.float8e4nv), - ) - tl.store( - fp8_rot_base + half_offset * 2 + 1, - tl.div_rn(r_odd, index_q_scale).to(tl.float8e4nv), - ) - - index_weights = tl.load( - index_weights_ptr + tok_idx * index_weights_stride + head_idx - ) - index_weights = index_weights.to(tl.float32) - index_weights *= index_q_scale - index_weights *= index_weights_softmax_scale - index_weights *= index_weights_head_scale - tl.store( - index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx, - index_weights, - ) - - -def launch_variant( - positions: torch.Tensor, - index_q: torch.Tensor, - index_q_cos_sin_cache: torch.Tensor, - index_weights: torch.Tensor, - index_weights_softmax_scale: float, - index_weights_head_scale: float, - index_q_fp8: torch.Tensor, - index_weights_out: torch.Tensor, - *, - use_fma: bool, -) -> None: - """Launch one FP8 fused indexer-Q variant into preallocated outputs.""" - assert positions.ndim == 1 - assert index_q.ndim == 3 - assert index_q_cos_sin_cache.ndim == 2 - assert index_q.shape[-1] == HEAD_DIM - assert index_q_cos_sin_cache.shape[-1] == ROPE_DIM - - num_tokens = positions.shape[0] - num_index_q_heads = index_q.shape[1] - _fused_indexer_q_rope_quant_variant_kernel[(num_tokens, num_index_q_heads)]( - positions, - index_q, - index_q.stride(0), - index_q.stride(1), - index_q_cos_sin_cache, - index_q_cos_sin_cache.stride(0), - index_q_cos_sin_cache.shape[-1] // 2, - index_q_fp8, - index_q_fp8.stride(0), - index_q_fp8.stride(1), - index_q.shape[-1], - index_weights, - index_weights.stride(0), - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out, - index_weights_out.stride(0), - USE_FMA=use_fma, - num_warps=1, - ) - - -def run_variant( - positions: torch.Tensor, - index_q: torch.Tensor, - index_q_cos_sin_cache: torch.Tensor, - index_weights: torch.Tensor, - index_weights_softmax_scale: float, - 
index_weights_head_scale: float, - *, - use_fma: bool, -) -> tuple[torch.Tensor, torch.Tensor]: - """Run one variant and allocate outputs like production FP8 path.""" - index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) - index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) - launch_variant( - positions, - index_q, - index_q_cos_sin_cache, - index_weights, - index_weights_softmax_scale, - index_weights_head_scale, - index_q_fp8, - index_weights_out, - use_fma=use_fma, - ) - return index_q_fp8, index_weights_out From 528835830f1d8f38dbeb8b73caa73a278ce21561 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 18 May 2026 09:50:32 -0500 Subject: [PATCH 7/9] remove indices Signed-off-by: tjtanaa --- .../test_rocm_fp8_fp4_paged_mqa_logits.py | 32 ++++++++----------- vllm/utils/deep_gemm.py | 15 +-------- .../ops/rocm_fp8_fp4_paged_mqa_logits.py | 14 +++----- 3 files changed, 19 insertions(+), 42 deletions(-) diff --git a/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py b/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py index b0e817396368..3108a5b8ad3a 100644 --- a/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py +++ b/tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py @@ -83,9 +83,7 @@ def _deepgemm_paged_mqa_cases() -> list[PagedMQACase]: DEEPGEMM_PAGED_MQA_CASES = _deepgemm_paged_mqa_cases() -FULL_DEEPGEMM_SHAPES = os.getenv( - "VLLM_ROCM_PAGED_MQA_FULL_DEEPGEMM_SHAPES", "0" -) == "1" +FULL_DEEPGEMM_SHAPES = os.getenv("VLLM_ROCM_PAGED_MQA_FULL_DEEPGEMM_SHAPES", "0") == "1" def _scaled_case_dims(case: PagedMQACase) -> tuple[int, int, int]: @@ -272,11 +270,14 @@ def test_rocm_fp4_mqa_logits_matches_reference() -> None: seq_len_kv = 19 num_heads = 8 head_dim = 128 - q = torch.randn( - (seq_len, num_heads, head_dim), - device=device, - dtype=torch.bfloat16, - ) * 0.125 + q = ( + torch.randn( + (seq_len, num_heads, head_dim), + device=device, + dtype=torch.bfloat16, + ) + * 0.125 + ) k = 
torch.randn((seq_len_kv, head_dim), device=device, dtype=torch.bfloat16) k = k * 0.125 weights = torch.randn((seq_len, num_heads), device=device, dtype=torch.float32) @@ -407,13 +408,10 @@ def test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> Non device=device, dtype=torch.int32, ) - indices = torch.arange(raw_batch_size, device=device, dtype=torch.int32) - indices = indices.repeat_interleave(tokens_per_seq) batch_size = int(tokens_per_seq.sum().item()) next_n = 1 else: tokens_per_seq = None - indices = None batch_size = raw_batch_size next_n = raw_next_n @@ -447,9 +445,9 @@ def test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> Non else: max_ctx_len_per_seq = context_lens - num_blocks_per_query = torch.ceil( - max_ctx_len_per_seq.float() / case.block_kv - ).to(torch.int32) + num_blocks_per_query = torch.ceil(max_ctx_len_per_seq.float() / case.block_kv).to( + torch.int32 + ) total_used_blocks = int(num_blocks_per_query.sum().item()) num_total_blocks = total_used_blocks + 8 kv_cache = torch.randn( @@ -499,8 +497,8 @@ def test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> Non ), q_scales_u8.view(torch.int32).squeeze(-1), ) - q_simulated = _dequantize_mxfp4(q_packed, q_scales_u8).view_as(q).to( - torch.bfloat16 + q_simulated = ( + _dequantize_mxfp4(q_packed, q_scales_u8).view_as(q).to(torch.bfloat16) ) kv_in, kv_simulated = _kv_cache_cast_to_fp4(kv_cache) else: @@ -512,7 +510,6 @@ def test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> Non context_lens_nextn, case.block_kv, get_num_sms(), - indices=indices, ) assert schedule_metadata.shape == (get_num_sms() + 1, 2) assert schedule_metadata.dtype == torch.int32 @@ -527,7 +524,6 @@ def test_rocm_fp8_fp4_paged_mqa_logits_deepgemm_cases(case: PagedMQACase) -> Non max_model_len, clean_logits=False, logits_dtype=case.logits_dtype, - indices=indices, ) assert logits.dtype == case.logits_dtype diff --git a/vllm/utils/deep_gemm.py 
b/vllm/utils/deep_gemm.py index ffb03ae2303e..3efb5f3c781c 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -403,7 +403,6 @@ def get_paged_mqa_logits_metadata( context_lens: torch.Tensor, block_size: int, num_sms: int, - indices: torch.Tensor | None = None, ) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. @@ -412,8 +411,6 @@ def get_paged_mqa_logits_metadata( effective context length per batch element or per decoded token. block_size: KV-cache block size in tokens (e.g., 64). num_sms: Number of SMs/CUs available. - indices: Optional varlen token-to-sequence indices for DeepGEMM SM100 - style scheduling. ROCm accepts this for API compatibility. Returns: Backend-specific tensor consumed by `fp8_fp4_paged_mqa_logits` to @@ -428,17 +425,12 @@ def get_paged_mqa_logits_metadata( context_lens, block_size, num_sms, - indices=indices, ) _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - if indices is None: - return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) - return _get_paged_mqa_logits_metadata_impl( - context_lens, block_size, num_sms, indices=indices - ) + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) def fp8_fp4_paged_mqa_logits( @@ -451,7 +443,6 @@ def fp8_fp4_paged_mqa_logits( max_model_len: int, clean_logits: bool, logits_dtype: torch.dtype = torch.float32, - indices: torch.Tensor | None = None, ) -> torch.Tensor: """Compute MQA logits using a paged KV-cache. @@ -478,7 +469,6 @@ def fp8_fp4_paged_mqa_logits( max_model_len: Maximum sequence length used to size the logits output. clean_logits: Whether to clean the unfilled logits into `-inf`. logits_dtype: Output dtype, matching DeepGEMM's float32/bfloat16 API. - indices: Optional varlen token-to-sequence indices. 
Returns: Logits tensor of shape [B * next_n, max_model_len], dtype @@ -499,7 +489,6 @@ def fp8_fp4_paged_mqa_logits( max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype, - indices=indices, ) _lazy_init() @@ -508,8 +497,6 @@ def fp8_fp4_paged_mqa_logits( kwargs: dict[str, Any] = {"clean_logits": clean_logits} if logits_dtype is not torch.float32: kwargs["logits_dtype"] = logits_dtype - if indices is not None: - kwargs["indices"] = indices return _fp8_fp4_paged_mqa_logits_impl( q, kv_cache, diff --git a/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py index 290f1b020af3..daabb808e5c6 100644 --- a/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py +++ b/vllm/v1/attention/ops/rocm_fp8_fp4_paged_mqa_logits.py @@ -113,9 +113,7 @@ def _fp8_paged_mqa_logits_kernel( block_table_ptr + batch * block_table_stride_b + logical_block ) block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B - scale_base = (block_base + KV_BLOCK_SIZE * HEAD_DIM).to( - tl.pointer_type(tl.float32) - ) + scale_base = (block_base + KV_BLOCK_SIZE * HEAD_DIM).to(tl.pointer_type(tl.float32)) h = tl.arange(0, NUM_HEADS) d = tl.arange(0, HEAD_DIM) @@ -233,9 +231,7 @@ def _fp4_paged_mqa_logits_kernel( other=0, ) block_base = kv_cache_ptr + physical_block.to(tl.int64) * KV_CACHE_STRIDE_B - scale_base = (block_base + KV_BLOCK_SIZE * value_dim).to( - tl.pointer_type(tl.int32) - ) + scale_base = (block_base + KV_BLOCK_SIZE * value_dim).to(tl.pointer_type(tl.int32)) q_packed = tl.load( q_ptr @@ -296,7 +292,6 @@ def rocm_get_paged_mqa_logits_metadata( context_lens: torch.Tensor, block_size: int, num_sms: int, - indices: torch.Tensor | None = None, ) -> torch.Tensor: """Return a DeepGEMM-compatible metadata tensor for ROCm. @@ -305,7 +300,7 @@ def rocm_get_paged_mqa_logits_metadata( same ``[num_sms + 1, 2]`` int32 shape keeps callers compatible with DeepGEMM's API and shape checks. 
""" - del block_size, indices + del block_size return torch.empty( (int(num_sms) + 1, 2), dtype=torch.int32, device=context_lens.device ) @@ -548,10 +543,9 @@ def rocm_fp8_fp4_paged_mqa_logits( max_context_len: int, clean_logits: bool = False, logits_dtype: torch.dtype = torch.float32, - indices: torch.Tensor | None = None, ) -> torch.Tensor: """Compute DeepGEMM-compatible paged MQA logits on ROCm.""" - del schedule_metadata, indices + del schedule_metadata q_values, q_scales = q is_fp4 = q_scales is not None From 99fd8f902a4837c682e550fe97d2013a7f7c7a24 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 18 May 2026 10:10:03 -0500 Subject: [PATCH 8/9] update comment Signed-off-by: tjtanaa --- vllm/config/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index c446f92c56f0..cf50d28c81ba 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -48,7 +48,8 @@ class AttentionConfig: """If set, quantize query for attention in prefill.""" use_fp4_indexer_cache: bool = False - """If set, use fp4 indexer cache for dsv32 family models. + """If set, use fp4 indexer cache for dsv32 and dsv4 family models. + But fp4 indexer cache is not supported for dsv32 family models. 
Supported on CUDA SM100 datacenter GPUs and ROCm gfx95x GPUs.""" use_non_causal: bool = False From 490a7da304c576b4f0ed0438efd9f1c868b7f59d Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 18 May 2026 11:52:19 -0500 Subject: [PATCH 9/9] remove the use of topk torch Signed-off-by: tjtanaa --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 161 +++--------------- 1 file changed, 25 insertions(+), 136 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 62d7e6955baf..af6ae4710fc3 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -617,45 +617,6 @@ def rocm_fp8_mqa_logits( return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -def _topk_indices_torch( - logits: torch.Tensor, - topk_tokens: int, - row_starts: torch.Tensor | None = None, -) -> torch.Tensor: - k = min(topk_tokens, logits.shape[-1]) - values, indices = torch.topk(logits, k=k, dim=-1) - indices = indices.to(torch.int32) - indices = torch.where( - values == float("-inf"), - torch.full_like(indices, -1, dtype=torch.int32), - indices, - ) - if row_starts is not None: - # Match the CUDA top_k_per_row_prefill contract: indices are local to - # each row's valid [row_start, row_end) range, not columns in the - # concatenated chunk logits matrix.
- starts = row_starts.to(dtype=torch.int32).view(-1, 1) - indices = torch.where(indices < 0, indices, indices - starts) - if k == topk_tokens: - return indices - padded = torch.full( - (logits.shape[0], topk_tokens), - -1, - dtype=torch.int32, - device=logits.device, - ) - padded[:, :k] = indices - return padded - - -def _has_vllm_topk_ops() -> bool: - return ( - hasattr(torch.ops, "_C") - and hasattr(torch.ops._C, "top_k_per_row_prefill") - and hasattr(torch.ops._C, "top_k_per_row_decode") - ) - - @triton.jit def _fill_prefill_topk_indices_kernel( row_starts_ptr, @@ -762,74 +723,6 @@ def _fill_decode_topk_indices( return indices -def _topk_indices_vllm_prefill( - logits: torch.Tensor, - topk_tokens: int, - row_starts: torch.Tensor, - row_ends: torch.Tensor, - out: torch.Tensor | None = None, -) -> torch.Tensor: - if not _has_vllm_topk_ops(): - indices = _topk_indices_torch(logits, topk_tokens, row_starts) - if out is not None: - out.copy_(indices) - return out - return indices - - indices = out - if indices is None: - indices = torch.empty( - (logits.shape[0], topk_tokens), - dtype=torch.int32, - device=logits.device, - ) - torch.ops._C.top_k_per_row_prefill( - logits, - row_starts, - row_ends, - indices, - logits.shape[0], - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - return indices - - -def _topk_indices_vllm_decode( - logits: torch.Tensor, - topk_tokens: int, - next_n: int, - seq_lens: torch.Tensor, - out: torch.Tensor | None = None, -) -> torch.Tensor: - if not _has_vllm_topk_ops(): - indices = _topk_indices_torch(logits, topk_tokens) - if out is not None: - out.copy_(indices) - return out - return indices - - indices = out - if indices is None: - indices = torch.empty( - (logits.shape[0], topk_tokens), - dtype=torch.int32, - device=logits.device, - ) - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - seq_lens, - indices, - logits.shape[0], - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - return indices - - def 
rocm_aiter_sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, @@ -859,9 +752,7 @@ def rocm_aiter_sparse_attn_indexer_fake( ) device = hidden_states.device if k is None else k.device flattened_width = ( - head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE - if use_fp4_cache - else head_dim + 4 + head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE if use_fp4_cache else head_dim + 4 ) _flattened_kv = torch.empty( [total_seq_lens, flattened_width], device=device, dtype=torch.uint8 @@ -1055,7 +946,6 @@ def rocm_aiter_sparse_attn_indexer_native( weight_slice = weights[chunk.token_start : chunk.token_end] if use_fp4_cache: assert q_scale is not None - use_vllm_topk = _has_vllm_topk_ops() q_scale_slice = q_scale[chunk.token_start : chunk.token_end] logits = rocm_fp4_mqa_logits( (q_slice.view(torch.int8), q_scale_slice), @@ -1066,7 +956,7 @@ def rocm_aiter_sparse_attn_indexer_native( weight_slice, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, - clean_logits=not use_vllm_topk, + clean_logits=False, ) else: logits = rocm_fp8_mqa_logits( @@ -1076,22 +966,16 @@ def rocm_aiter_sparse_attn_indexer_native( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - if use_fp4_cache: - _topk_indices_vllm_prefill( - logits, - topk_tokens, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - out=topk_indices, - ) - else: - _topk_indices_vllm_prefill( - logits, - topk_tokens, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - out=topk_indices, - ) + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + logits.shape[0], + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -1169,7 +1053,6 @@ def rocm_aiter_sparse_attn_indexer_native( out=topk_indices, ) else: - use_vllm_topk = _has_vllm_topk_ops() active_paged_width = ( decode_metadata.block_table.shape[1] * kv_cache_decode.shape[1] ) @@ -1185,14 +1068,17 @@ def rocm_aiter_sparse_attn_indexer_native( 
decode_metadata.block_table, decode_metadata.schedule_metadata, max_model_len=logits_width, - clean_logits=not use_vllm_topk, + clean_logits=False, ) - _topk_indices_vllm_decode( + torch.ops._C.top_k_per_row_decode( logits, - topk_tokens, next_n, decode_metadata.seq_lens, - out=topk_indices, + topk_indices, + logits.shape[0], + logits.stride(0), + logits.stride(1), + topk_tokens, ) else: logits = rocm_fp8_paged_mqa_logits( @@ -1204,12 +1090,15 @@ def rocm_aiter_sparse_attn_indexer_native( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - _topk_indices_vllm_decode( + torch.ops._C.top_k_per_row_decode( logits, - topk_tokens, next_n, decode_metadata.seq_lens, - out=topk_indices, + topk_indices, + logits.shape[0], + logits.stride(0), + logits.stride(1), + topk_tokens, ) if decode_metadata.requires_padding: