From 9b5e76db05b42895708c0bff22aac78de990dde9 Mon Sep 17 00:00:00 2001 From: chenjun Date: Tue, 28 Apr 2026 05:17:34 -0500 Subject: [PATCH 1/2] dsa: remove block_table_convert_triton in dsa --- atom/model_ops/attention_mla.py | 14 ++++++++++---- atom/model_ops/attentions/aiter_mla.py | 26 +++++++++++++------------- atom/models/deepseek_v2.py | 8 +++++++- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 89a74f985..798a9f67d 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -484,6 +484,7 @@ def _forward_prefill_mla( attn_metadata.block_tables, attn_metadata.cu_seqlens_k, NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1], + PAGE_SIZE=get_current_atom_config().kv_cache_block_size, ) paged_cu_seqlens_q = attn_metadata.sparse_cu_seqlens_q paged_kv_indptr = attn_metadata.sparse_kv_indptr @@ -916,7 +917,7 @@ def _convert_req_index_to_global_index_dsa_prefill_kernel( out_kv_indices, # int32 # shapes (compile-time where possible) NUM_TOPK_TOKENS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + PAGE_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, # tile width along columns # strides (in elements) ti_stride0: tl.int64, # topk_indices stride 0 @@ -941,14 +942,19 @@ def _convert_req_index_to_global_index_dsa_prefill_kernel( ) # int32 pre_seqlens_q = tl.load(cu_seqlens_q + req_id) + seq_token_idx = indice - pre_seqlens_q + block_id = seq_token_idx // PAGE_SIZE + inblock_offset = seq_token_idx % PAGE_SIZE + # Guard block_table access store_mask = (col_id < kv_len) & (col_id < NUM_TOPK_TOKENS) valid_mask = store_mask & (indice >= 0) - out_val = tl.load( - block_table + req_id * bt_stride0 + (indice - pre_seqlens_q) * bt_stride1, + physical_block = tl.load( + block_table + req_id * bt_stride0 + block_id * bt_stride1, mask=valid_mask, other=-1, ) + out_val = tl.where(valid_mask, physical_block * PAGE_SIZE + inblock_offset, -1) # Store results out_ptr_ij = out_kv_indices + kv_start + col_id @@ -967,7 +973,7 @@ def triton_convert_req_index_to_global_index_dsa_prefill( block_table: torch.Tensor, # int32 [num_req, max_num_blocks_per_req] cu_seqlens_q: torch.Tensor, # int32 [num_tokens + 1] # dsa_kv_indices: torch.Tensor, # int32 [total_kv_seqlen] -->>> output for this kernel - PAGE_SIZE: int = 1, # page_block_size = 1 for now + PAGE_SIZE: int = 1, NUM_TOPK_TOKENS: int = 2048, BLOCK_N: int = 1024, # tile width along columns ): diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 3f895df29..e3050018e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -381,14 +381,14 @@ def prepare_prefill(self, batch: ScheduledBatch): if attn_metadata.block_tables is None: self.prepare_block_tables(batch) attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) - if self.block_ratio > 1: - block_table_convert_triton( - var["block_tables"].gpu[:bs], - var["block_tables_converted"].gpu[:bs], - var["context_lens"].gpu[:bs], - self.block_ratio, - ) - attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] + # if self.block_ratio > 1: + # block_table_convert_triton( + # var["block_tables"].gpu[:bs], + # var["block_tables_converted"].gpu[:bs], + # var["context_lens"].gpu[:bs], + # self.block_ratio, + # ) + # attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs] if attn_metadata.has_cached: # Full context (cached + new): use cu_seqlens_k for indexer @@ -787,11 +787,11 @@ def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: kv_indices=var["kv_indices"].gpu, kv_last_page_lens=var["kv_last_page_lens"].gpu[:bs], sparse_kv_indptr=sparse_kv_indptr, - block_tables_converted=( - var["block_tables_converted"].gpu[:bs] - if "block_tables_converted" in var - else None - ), + # block_tables_converted=( + # var["block_tables_converted"].gpu[:bs] + # if "block_tables_converted" in var + # else None + # ), **ctx_mla_ps, ) attn_matadata.dtype_q = self.dtype_q diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index a04dcfb20..e7f0f2434 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1012,12 +1012,15 @@ def sparse_attn_indexer( # dummy runner return weights num_decode_tokens = context.batch_size if not context.is_prefill else 0 + ori_block_size = get_current_atom_config().kv_cache_block_size + kv_cache = kv_cache.view(-1, ori_block_size, kv_cache.shape[-1]) indexer_k_quant_and_cache( k, kv_cache, slot_mapping, quant_block_size, scale_fmt, + preshuffle=True, ) if context.is_prefill: if attn_metadata.max_seqlen_k <= topk_indices_buffer.shape[1]: @@ -1049,6 +1052,7 @@ def sparse_attn_indexer( if prefill_metadata.has_cached else prefill_metadata.cu_seqlens_q ), + preshuffle=True, ) cu_seqlen_ks = prefill_metadata.cu_seqlen_ks cu_seqlen_ke = prefill_metadata.cu_seqlen_ke @@ -1100,6 +1104,8 @@ def sparse_attn_indexer( decode_metadata.context_lens, attn_metadata.block_tables, max_model_len, + KVBlockSize=kv_cache.shape[1], + Preshuffle=True, ) num_rows = logits.shape[0] assert topk_tokens == 2048, "top_k_per_row assumes size 2048" @@ -1192,7 +1198,7 @@ def __init__( self.weights_proj = ReplicatedLinear( hidden_size, self.n_head, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.weights_proj", ) self.softmax_scale = self.head_dim**-0.5 From ed1bec05a349c7e6668244d87374fa4c58ae11be Mon Sep 17 00:00:00 2001 From: chenjun Date: Thu, 30 Apr 2026 01:28:44 -0500 Subject: [PATCH 2/2] clear code --- atom/model_ops/attentions/aiter_mla.py | 22 ---------------------- atom/models/deepseek_v2.py | 6 +++--- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 8891770f3..84058585e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -489,14 +489,6 @@ def prepare_prefill(self, batch: ScheduledBatch): if attn_metadata.block_tables is None: self.prepare_block_tables(batch) attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) - # if self.block_ratio > 1: - # block_table_convert_triton( - # var["block_tables"].gpu[:bs], - # var["block_tables_converted"].gpu[:bs], - # var["context_lens"].gpu[:bs], - # self.block_ratio, - # ) - # attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs] if attn_metadata.has_cached: # Full context (cached + new): use cu_seqlens_k for indexer @@ -704,15 +696,6 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): ctx_mla_ps = self.set_mla_persistent_worker_buffers(bs, max_seqlen_q) ctx.update(ctx_mla_ps) current_stream.wait_stream(prep_stream) - # if self.block_ratio > 1: - # if "block_tables" in ctx: - # block_table_convert_triton( - # var["block_tables"].gpu[:bs], - # var["block_tables_converted"].gpu[:bs], - # var["context_lens"].gpu[:bs], - # self.block_ratio, - # ) - # ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] attn_metadata = AttentionMetaData( dropout_p=dropout_p, max_seqlen_q=max_seqlen_q, @@ -895,11 +878,6 @@ def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: kv_indices=var["kv_indices"].gpu, kv_last_page_lens=var["kv_last_page_lens"].gpu[:bs], sparse_kv_indptr=sparse_kv_indptr, - # block_tables_converted=( - # var["block_tables_converted"].gpu[:bs] - # if "block_tables_converted" in var - # else None - # ), **ctx_mla_ps, ) attn_matadata.dtype_q = self.dtype_q diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index e7f0f2434..04bfd2054 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1012,8 +1012,8 @@ def sparse_attn_indexer( # dummy runner return weights num_decode_tokens = context.batch_size if not context.is_prefill else 0 - ori_block_size = get_current_atom_config().kv_cache_block_size - kv_cache = kv_cache.view(-1, ori_block_size, kv_cache.shape[-1]) + runner_block_size = get_current_atom_config().kv_cache_block_size + kv_cache = kv_cache.view(-1, runner_block_size, kv_cache.shape[-1]) indexer_k_quant_and_cache( k, kv_cache, @@ -1104,7 +1104,7 @@ def sparse_attn_indexer( decode_metadata.context_lens, attn_metadata.block_tables, max_model_len, - KVBlockSize=kv_cache.shape[1], + KVBlockSize=runner_block_size, Preshuffle=True, ) num_rows = logits.shape[0]