Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions atom/model_ops/attentions/deepseek_v4_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
AttentionMetadataBuilder,
CommonAttentionBuilder,
)
from atom.utils import CpuGpuBuffer
from atom.utils.forward_context import AttentionMetaData


Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, model_runner):
self.head_dim = getattr(hf, "kv_head_dim", 512)
self.index_head_dim = getattr(hf, "index_head_dim", 128)
self.window_size = getattr(hf, "sliding_window", 128)
self.index_topk = getattr(hf, "index_topk", 1024)

# Compressor state shape: [coff * ratio, coff * head_dim], fp32.
# CSA: ratio=4, overlap=True -> coff=2 -> [8, 2*head_dim]
Expand All @@ -114,6 +116,24 @@ def __init__(self, model_runner):
self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype
self._classical_dtype = torch.bfloat16 # compressed KV is BF16

# Sparse-attn layout metadata, aligned with aiter_mla's CpuGpuBuffer
# convention. Values are per-token and reusable across layers with the
# same compress_ratio; the actual topk indices remain layer-specific.
i32_kwargs = {"dtype": torch.int32, "device": self.device}
i64_kwargs = {"dtype": torch.int64, "device": self.device}
v4_sparse_metadata = {}
for kind in ("dense", "csa", "hca"):
v4_sparse_metadata[f"v4_{kind}_sparse_topk_starts"] = CpuGpuBuffer(
self.max_num_batched_tokens, **i64_kwargs
)
v4_sparse_metadata[f"v4_{kind}_sparse_topk_lens"] = CpuGpuBuffer(
self.max_num_batched_tokens, **i32_kwargs
)
v4_sparse_metadata[f"v4_{kind}_sparse_kv_offsets"] = CpuGpuBuffer(
self.max_num_batched_tokens, **i32_kwargs
)
self.model_runner.forward_vars.update(v4_sparse_metadata)

# ------------------------------------------------------------------ #
# AttentionMetadataBuilder hooks (per-request cache abstraction). #
# ------------------------------------------------------------------ #
Expand Down Expand Up @@ -314,6 +334,92 @@ def build_kv_cache_tensor(self, layer_id: int, module):
# the rest of ATOM and to support PR3-main multi-sequence wiring). #
# ------------------------------------------------------------------ #

def _attach_sparse_layout_metadata(
    self,
    attn_metadata: "AttentionMetaData",
    cu_seqlens_q_np,
    start_pos_per_seq,
    scheduled_bs: int,
    total_tokens: int,
) -> None:
    """Precompute the per-token ragged sparse-attn layout for each ratio type.

    The actual topk index values are layer-specific (the CSA Indexer depends
    on weights), but the per-token topk span and the global-KV offset layout
    depend only on the request geometry and the compress ratio, so they are
    computed once here for the three kinds (dense=0, csa=4, hca=128).

    For every kind, three ``forward_vars`` buffers are filled on the host
    side (through their ``.np`` view) and then pushed with
    ``copy_to_gpu(total_tokens)``:

    * ``v4_<kind>_sparse_topk_starts`` -- running start offset of each
      token's topk span (each token of a sequence owns ``topk_len`` slots).
    * ``v4_<kind>_sparse_topk_lens`` -- ``topk_len`` of the token's sequence.
    * ``v4_<kind>_sparse_kv_offsets`` -- running KV base of the token's
      sequence.

    Args:
        attn_metadata: Metadata object that receives ``v4_sparse_layouts``.
        cu_seqlens_q_np: Cumulative query lengths; entries ``i`` and
            ``i + 1`` delimit sequence ``i``.
        start_pos_per_seq: Start position of each scheduled sequence.
        scheduled_bs: Number of sequences to walk.
        total_tokens: Number of leading buffer entries copied to the GPU.

    Returns:
        None. The result is stored on ``attn_metadata.v4_sparse_layouts`` as
        ``{ratio: {"topk_starts", "topk_lens", "kv_offsets", "max_topk"}}``.
    """
    import numpy as np

    forward_vars = self.model_runner.forward_vars
    window = self.window_size

    # Request geometry is the same for every ratio, so gather it once.
    # Sequences that contribute no tokens in this step are dropped here.
    seq_geometry = []
    for seq_idx in range(scheduled_bs):
        token_num = int(cu_seqlens_q_np[seq_idx + 1]) - int(
            cu_seqlens_q_np[seq_idx]
        )
        if token_num == 0:
            continue
        seq_geometry.append((token_num, int(start_pos_per_seq[seq_idx])))

    layouts = {}
    for ratio, kind in ((0, "dense"), (4, "csa"), (128, "hca")):
        starts_buf = forward_vars[f"v4_{kind}_sparse_topk_starts"]
        lens_buf = forward_vars[f"v4_{kind}_sparse_topk_lens"]
        offsets_buf = forward_vars[f"v4_{kind}_sparse_kv_offsets"]
        starts, lens, offsets = starts_buf.np, lens_buf.np, offsets_buf.np

        topk_base = 0
        kv_base = 0
        token_base = 0
        max_topk = 0
        for token_num, start_pos in seq_geometry:
            end_pos = start_pos + token_num

            # Window part: a sequence with earlier context uses the full
            # window, a fresh one is capped by its own token count.
            if start_pos > 0:
                window_topk = window
            else:
                window_topk = min(token_num, window)
            kv_len = token_num if start_pos == 0 else window

            # Compressed part: both compressed kinds extend the KV span the
            # same way; only the number of selected entries differs.
            compress_topk = 0
            if ratio:
                if start_pos == 0:
                    kv_len += (token_num + ratio - 1) // ratio
                else:
                    kv_len += end_pos // ratio
                if kind == "csa":
                    compress_topk = min(self.index_topk, end_pos // ratio)
                elif start_pos > 0:
                    compress_topk = (start_pos + 1) // ratio
                else:
                    compress_topk = token_num // ratio

            topk_len = window_topk + compress_topk
            token_end = token_base + token_num
            # Multiply instead of np.arange(start, stop, step): a zero
            # topk_len would otherwise be an invalid zero step.
            starts[token_base:token_end] = topk_base + topk_len * np.arange(
                token_num, dtype=np.int64
            )
            lens[token_base:token_end] = topk_len
            offsets[token_base:token_end] = kv_base

            token_base = token_end
            topk_base += token_num * topk_len
            kv_base += kv_len
            max_topk = max(max_topk, topk_len)

        # Entries past the last written token would otherwise carry values
        # left over from an earlier step; give them an empty span instead.
        if token_base < total_tokens:
            starts[token_base:total_tokens] = 0
            lens[token_base:total_tokens] = 0
            offsets[token_base:total_tokens] = 0

        layouts[ratio] = {
            "topk_starts": starts_buf.copy_to_gpu(total_tokens),
            "topk_lens": lens_buf.copy_to_gpu(total_tokens),
            "kv_offsets": offsets_buf.copy_to_gpu(total_tokens),
            "max_topk": max_topk,
        }
    attn_metadata.v4_sparse_layouts = layouts

def prepare_decode(self, batch: ScheduledBatch, bs: int):
"""V4-style decode prep: populates positions, cu_seqlens_q,
block_tables, and state_slot_mapping.
Expand Down Expand Up @@ -386,6 +492,13 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int):
attn_metadata.start_pos_per_seq_cpu = positions_np[
cu_seqlens_q_np[:scheduled_bs]
]
self._attach_sparse_layout_metadata(
attn_metadata,
cu_seqlens_q_np,
attn_metadata.start_pos_per_seq_cpu,
scheduled_bs,
sum_scheduled_tokens,
)
return attn_metadata, positions

def prepare_prefill(self, batch: ScheduledBatch):
Expand Down Expand Up @@ -422,6 +535,13 @@ def prepare_prefill(self, batch: ScheduledBatch):
attn_metadata.start_pos_per_seq_cpu = positions_np[
cu_seqlens_q_np[:scheduled_bs]
]
self._attach_sparse_layout_metadata(
attn_metadata,
cu_seqlens_q_np,
attn_metadata.start_pos_per_seq_cpu,
scheduled_bs,
sum_scheduled_tokens,
)
return attn_metadata, positions

def _populate_block_tables(
Expand Down
Loading