diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index bdbe46ad9177..4a2e909d50ba 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -216,7 +216,7 @@ configuration. | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 54d2cb53b0d1..15913c418b05 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -674,30 +674,45 @@ def forward( ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) - q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - - k = self.k_norm(k) - k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation - # so we need to reshape back to token-flattened shapes - q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) - k_pe = k_pe.reshape(-1, 1, self.rope_dim) - - # `rotary_emb` is shape-preserving; `q_pe` is already - # [num_tokens, n_head, rope_dim]. - q = torch.cat([q_pe, q_nope], dim=-1) - # `k_pe` is [num_tokens, 1, rope_dim] (MQA). 
-        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
+        if current_platform.is_rocm():
+            # This path should work on all platforms; the extra
+            # branches will be removed in the future.
+            # Fused wk + weights_proj: one GEMM, then split
+            kw, _ = self.wk_weights_proj(hidden_states)
+            k = kw[:, : self.head_dim]
+            weights = kw[:, self.head_dim :]
+
+            k = self.k_norm(k)
+
+            rotary_emb(
+                positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1)
+            )
+        else:
+            q_pe, q_nope = torch.split(
+                q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
+            )
+            # Fused wk + weights_proj: one GEMM, then split
+            kw, _ = self.wk_weights_proj(hidden_states)
+            k = kw[:, : self.head_dim]
+            weights = kw[:, self.head_dim :]
+
+            k = self.k_norm(k)
+            k_pe, k_nope = torch.split(
+                k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
+            )
+
+            q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
+            # Note: RoPE (NeoX) can introduce extra leading dimensions during
+            # compilation so we need to reshape back to token-flattened shapes
+            q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
+            k_pe = k_pe.reshape(-1, 1, self.rope_dim)
+
+            # `rotary_emb` is shape-preserving; `q_pe` is already
+            # [num_tokens, n_head, rope_dim].
+            q = torch.cat([q_pe, q_nope], dim=-1)
+            # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
+            k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 5d12d27e7625..7c0715a9e8b6 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -122,7 +122,7 @@ def get_name() -> str:
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-        return [1 if current_platform.is_rocm() else 64]
+        return [1, 64] if current_platform.is_rocm() else [64]
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index a66a97311fbc..2106226118ef 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -396,6 +396,7 @@ class AiterMLAHelper:
     """
     _AITER_MIN_MLA_HEADS: Final = 16
+    _AITER_UNSUPPORTED_HEADS = [32]
     @staticmethod
     def check_num_heads_validity(num_heads: int):
@@ -419,6 +420,9 @@ def get_actual_mla_num_heads(num_heads: int) -> int:
     @staticmethod
     def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
+        assert num_heads not in AiterMLAHelper._AITER_UNSUPPORTED_HEADS, (
+            f"unsupported num_heads: {num_heads}"
+        )
         return (
             q
             if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index 503bb509b105..dc343b639f6c 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
+from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
@@ -14,6 +15,7 @@
 from vllm.model_executor.layers.attention.mla_attention import (
     get_mla_dims,
 )
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -25,9 +27,6 @@
     MultipleOf,
     SparseMLAAttentionImpl,
 )
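+# NOTE: ROCm uses a local copy of triton_convert_req_index_to_global_index
+# (defined below) that scatters directly into the preallocated
+# paged_kv_indices buffer, replacing the flashmla_sparse variant whose import
+# is removed here.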
-from vllm.v1.attention.backends.mla.flashmla_sparse import (
-    triton_convert_req_index_to_global_index,
-)
 from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
     AiterMLAHelper,
 )
@@ -38,6 +37,188 @@
 logger = init_logger(__name__)
+@triton.jit
+def _convert_req_index_to_global_index_kernel(
+    req_id_ptr,  # int32 [num_tokens]
+    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    cu_seqlens_ptr,  # int32 [num_tokens + 1]
+    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    # shapes (compile-time where possible)
+    max_num_blocks_per_req: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_N: tl.constexpr,  # tile width along columns
+    # strides (in elements)
+    bt_stride0,
+    bt_stride1,
+    ti_stride0,
+    ti_stride1,
+):
+    # program_id(0) -> token_id (row)
+    # program_id(1) -> tile index along columns
+    token_id = tl.program_id(0)
+    tile_id = tl.program_id(1)
+
+    # Each program covers BLOCK_N consecutive columns
+    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # Load request id for this token (no mask: grid is exact)
+    req = tl.load(req_id_ptr + token_id)
+
+    # Load cumulative sequence lengths to get starting index of this request
+    seq_start = tl.load(cu_seqlens_ptr + token_id)
+    seq_end = tl.load(cu_seqlens_ptr + token_id + 1)
+
+    if tile_id * BLOCK_N + seq_start >= seq_end:
+        return
+
+    # Load token indices for this tile
+    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
+    tok = tl.load(ti_ptr)  # int32
+
+    # token == -1 marks an invalid (padding) entry
+    is_invalid_tok = tok < 0
+
+    # Compute block id and in-block offset
+    block_id = tok // BLOCK_SIZE
+    inblock_off = tok % BLOCK_SIZE
+
+    # Guard block_table access
+    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
+    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
+    base = tl.load(bt_ptr, mask=valid_block, other=0)
+
+    # If token == -1 OR block_id OOB, output 0; else base * BLOCK_SIZE + offset
+    out_val = tl.where(
+        is_invalid_tok | (~valid_block), 0, base * BLOCK_SIZE + inblock_off
+    )
+    out_ptr_ij = out_ptr + seq_start + indice_id
+    out_ptr_ij_mask = (seq_start + indice_id) < seq_end
+
+    # store the results with mask
+    tl.store(out_ptr_ij, out_val, mask=out_ptr_ij_mask)
+
+
+def triton_convert_req_index_to_global_index(
+    req_id: torch.Tensor,  # int32 [num_tokens]
+    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    cu_seqlens: torch.Tensor,  # int32 [num_tokens + 1]
+    paged_kv_indices: torch.Tensor,  # int32 [num_tokens * topk] out_buffer
+    BLOCK_SIZE: int = 64,
+    NUM_TOPK_TOKENS: int = 2048,
+    BLOCK_N: int = 128,  # tile width along columns
+):
+    """
+    paged_kv_indices[cu_seqlens[token_id] + indice_id] =
+        block_table[req_id[token_id],
+        token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
+        + token_indices[token_id, indice_id] % BLOCK_SIZE
+
+    Entries with token_indices[token_id, indice_id] == -1 are padding and are
+    written as 0; for safety, 0 is also written when the derived block_id
+    would be out of bounds. Positions at or beyond cu_seqlens[token_id + 1]
+    are left untouched.
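+
+    Illustrative example (assumed values): with BLOCK_SIZE=64 and
+    block_table[req] == [7, 3, 9], a token index of 130 maps to
+    block_table[req, 130 // 64] * 64 + 130 % 64 = 9 * 64 + 2 = 578.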
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + ) + # print("req_id: ", req_id, flush=True) + num_tokens = req_id.shape[0] + _, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + cu_seqlens, + paged_kv_indices, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + ) + return + + +@triton.jit +def generate_sparse_seqlen_kernel( + seq_len_ptr, # [num_seq] + cu_query_lens_ptr, # [num_seq] + out_ptr, # [num_query_tokens] + topk_token: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + query_offset = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + query_start = tl.load(cu_query_lens_ptr + seq_id) + query_end = tl.load(cu_query_lens_ptr + seq_id + 1) + if query_start + tl.program_id(1) * BLOCK_SIZE > query_end: + return + query_len = query_end - query_start + query_mask = query_offset + query_start < query_end + seq_len = tl.load(seq_len_ptr + seq_id) + # Just return since the out_ptr is zero initialized. 
+ if seq_len == 0: + return + context_start_point = seq_len - query_len + sparse_seqlen = context_start_point + query_offset + sparse_seqlen_masked = tl.where( + sparse_seqlen + 1 < topk_token, sparse_seqlen + 1, topk_token + ) + tl.store( + out_ptr + query_start + query_offset, sparse_seqlen_masked, mask=query_mask + ) + + +def generate_sparse_seqlen_triton( + query_lens: torch.Tensor, + seq_lens: torch.Tensor, + cu_query_lens: torch.Tensor, + topk_token: int, + num_tokens: int, + max_query_len: int, +): + num_seqs = query_lens.size(0) + # zero initialize the tensor to make sure invalid positions will be zero + out = torch.zeros([num_tokens], dtype=torch.int32, device=query_lens.device) + block_size = 64 + num_block_per_row = triton.cdiv(max_query_len, block_size) + grid = ( + num_seqs, + num_block_per_row, + ) + generate_sparse_seqlen_kernel[grid]( + seq_lens, + cu_query_lens, + out, + topk_token, + block_size, + ) + return out + + @triton.jit def fetch_id_to_ragged_kernel( in_tensor_ptr, # [num_seq, topk] @@ -86,11 +267,13 @@ class ROCMAiterMLASparseBackend(AttentionBackend): "auto", "float16", "bfloat16", + "fp8", + "fp8_e4m3", ] @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1] + return [1, 64] @staticmethod def get_name() -> str: @@ -144,7 +327,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): paged_kv_last_page_len: torch.Tensor paged_kv_indices: torch.Tensor paged_kv_indptr: torch.Tensor - paged_kv_indptr_rest: torch.Tensor + attn_out_dtype: torch.dtype block_size: int = 1 topk_tokens: int = 2048 @@ -167,6 +350,7 @@ def __init__( ): self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config + self.model_dtype = vllm_config.model_config.dtype parallel_config = vllm_config.parallel_config self.device = device max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -174,9 +358,6 @@ def __init__( self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 - ) self.max_model_len_tensor = torch.tensor( [self.model_config.max_model_len], device=device, dtype=torch.int32 ) @@ -222,18 +403,33 @@ def build( ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) - self.paged_kv_indices.fill_(0) - self.paged_kv_indptr.fill_(0) + query_lens = ( + common_attn_metadata.query_start_loc[1:] + - common_attn_metadata.query_start_loc[:-1] + ) + seq_lens = common_attn_metadata.seq_lens + sparse_seqlen = generate_sparse_seqlen_triton( + query_lens, + seq_lens, + common_attn_metadata.query_start_loc, + self.topk_tokens, + num_tokens, + common_attn_metadata.max_query_len, + ) + + torch.cumsum(sparse_seqlen, dim=0, out=self.paged_kv_indptr[1 : num_tokens + 1]) + self.paged_kv_indptr[num_tokens + 1 :].fill_(self.paged_kv_indptr[num_tokens]) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] qo_indptr = self.qo_indptr[: num_tokens + 1] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] - paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] - paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 
1 :] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -245,12 +441,12 @@ def build( block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, + attn_out_dtype=self.model_dtype, topk_tokens=self.topk_tokens, qo_indptr=qo_indptr, paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_indices=paged_kv_indices, paged_kv_indptr=paged_kv_indptr, - paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -314,29 +510,20 @@ def __init__( assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer - def _forward_bf16_kv( + def _forward_mla( self, + layer: AttentionLayer, q: torch.Tensor, # [sq, heads, d_qk] kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] - topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads) output = torch.empty( [num_tokens, mla_num_heads, self.kv_lora_rank], - dtype=q.dtype, + dtype=attn_metadata.attn_out_dtype, device=q.device, ) - seq_len = (topk_indices != -1).sum(dim=-1) - torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) - attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - fetch_id_to_ragged_triton( - topk_indices, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.topk_tokens, - ) rocm_aiter_ops.mla_decode_fwd( q, @@ -348,6 +535,8 @@ def _forward_bf16_kv( attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_len, + q_scale=layer._q_scale, + kv_scale=layer._k_scale, ) return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output) @@ -366,23 +555,32 @@ def forward_mqa( if isinstance(q, tuple): q = torch.cat(q, dim=-1) - num_actual_toks = q.shape[0] + num_actual_toks = attn_metadata.num_actual_tokens # Get topk indices assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] - topk_indices_global = triton_convert_req_index_to_global_index( + triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, BLOCK_SIZE=attn_metadata.block_size, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) + # write the latent and rope to kv cache + fp8_attention = self.kv_cache_dtype.startswith("fp8") + if fp8_attention: + original_q_shape = q.shape + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(current_platform.fp8_dtype()) + q, _ = ops.scaled_fp8_quant(q.view(q.shape[0], -1), layer._q_scale) + q = q.view(original_q_shape) mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q) - attn_out = self._forward_bf16_kv( - mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata + attn_out = self._forward_mla( + layer, mla_padded_q, kv_c_and_k_pe_cache, attn_metadata ) return attn_out, None diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 81cc489db0d8..627d870b62ff 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -13,9 +13,6 @@ from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton -if 
current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops - @triton.jit def _indexer_k_quant_and_cache_kernel( @@ -97,7 +94,8 @@ def indexer_k_quant_and_cache_triton( # In real layout, we store the first portion as kv cache value # and second portion as kv cache scale kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] + fp8_dtype = current_platform.fp8_dtype() + kv_cache_value = kv_cache[:, : block_size * head_dim].view(fp8_dtype) kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) head_tile_size = head_tile_size // kv_cache.element_size() grid = (num_tokens,) @@ -111,7 +109,7 @@ def indexer_k_quant_and_cache_triton( block_size, num_tokens, head_dim, - "NHD", + "SHUFFLE", block_tile_size, head_tile_size, IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, @@ -212,7 +210,7 @@ def cp_gather_indexer_k_quant_cache_triton( block_table_stride, k_cache_value.stride(0), k_cache_scale.stride(0), - "NHD", + "SHUFFLE", head_dim, block_tile_size, head_tile_size, @@ -325,33 +323,38 @@ def rocm_fp8_paged_mqa_logits( from vllm._aiter_ops import rocm_aiter_ops aiter_paged_mqa_logits_module = None + # if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() if aiter_paged_mqa_logits_module is not None: - deepgemm_fp8_paged_mqa_logits_stage1 = ( - aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 + deepgemm_fp8_paged_mqa_logits = ( + aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - # TODO: 1. Replace _stage1 and out_qk.sum with another fused variant; - # 2. Remove ChunkQ when AITER PR #2891 merged - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, - ChunkQ=heads, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len @@ -540,7 +543,7 @@ def rocm_aiter_sparse_attn_indexer( num_tokens = slot_mapping.shape[0] k = k[:num_tokens] - ops.indexer_k_quant_and_cache( + indexer_k_quant_and_cache_triton( k, kv_cache, slot_mapping, @@ -563,13 +566,13 @@ def rocm_aiter_sparse_attn_indexer( device=k.device, dtype=torch.uint8, ) - - ops.cp_gather_indexer_k_quant_cache( + cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, + chunk.token_to_seq, ) logits = rocm_fp8_mqa_logits(