From c060828b1f459c09822a386891a93188ea660445 Mon Sep 17 00:00:00 2001 From: chenjun Date: Fri, 1 May 2026 06:43:23 -0500 Subject: [PATCH 1/2] use triton sparse_attn_ragged --- atom/model_ops/sparse_attn_v4.py | 371 ++++++++++++++++++++++++++++++- atom/models/deepseek_v4.py | 47 +++- 2 files changed, 406 insertions(+), 12 deletions(-) diff --git a/atom/model_ops/sparse_attn_v4.py b/atom/model_ops/sparse_attn_v4.py index 790bf5eb00..890d44d9dc 100644 --- a/atom/model_ops/sparse_attn_v4.py +++ b/atom/model_ops/sparse_attn_v4.py @@ -13,16 +13,368 @@ kernel's accumulation precision. They are correct but not performant. """ +import os from typing import Tuple import torch +import triton +import triton.language as tl # --------------------------------------------------------------------------- # sparse_attn — FlashAttention-style sparse MQA with attention sink # --------------------------------------------------------------------------- -def sparse_attn( +@triton.jit +def _sparse_attn_triton_kernel( + q_ptr, + kv_ptr, + attn_sink_ptr, + topk_idxs_ptr, + out_ptr, + q_stride_b: tl.constexpr, + q_stride_m: tl.constexpr, + q_stride_h: tl.constexpr, + q_stride_d: tl.constexpr, + kv_stride_b: tl.constexpr, + kv_stride_n: tl.constexpr, + kv_stride_d: tl.constexpr, + topk_stride_b: tl.constexpr, + topk_stride_m: tl.constexpr, + topk_stride_k: tl.constexpr, + out_stride_b: tl.constexpr, + out_stride_m: tl.constexpr, + out_stride_h: tl.constexpr, + out_stride_d: tl.constexpr, + M: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + K: tl.constexpr, + softmax_scale: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_bm = tl.program_id(0) + pid_h = tl.program_id(1) + m = pid_bm % M + b = pid_bm // M + + h_offs = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + d_offs = tl.arange(0, BLOCK_D) + h_mask = h_offs < H + d_mask = d_offs < D + + q_base = b * q_stride_b + m * q_stride_m + q = tl.load( + q_ptr + q_base + h_offs[:, None] * q_stride_h + d_offs[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + ) + + neg_large = -3.4028234663852886e38 + m_i = tl.full((BLOCK_H,), neg_large, dtype=tl.float32) + l_i = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + k_offs = tl.arange(0, BLOCK_K) + for k_start in tl.range(0, K, BLOCK_K): + k_pos = k_start + k_offs + in_range = k_pos < K + idx = tl.load( + topk_idxs_ptr + + b * topk_stride_b + + m * topk_stride_m + + k_pos * topk_stride_k, + mask=in_range, + other=-1, + ) + valid = in_range & (idx >= 0) + + kv = tl.load( + kv_ptr + + b * kv_stride_b + + idx[:, None] * kv_stride_n + + d_offs[None, :] * kv_stride_d, + mask=valid[:, None] & d_mask[None, :], + other=0.0, + ) + + scores = tl.dot(q, tl.trans(kv)) * softmax_scale + scores = tl.where(h_mask[:, None] & valid[None, :], scores, neg_large) + + m_block = tl.max(scores, axis=1) + m_new = tl.maximum(m_i, m_block) + alpha = tl.exp(m_i - m_new) + p = tl.exp(scores - m_new[:, None]) + p = tl.where(h_mask[:, None] & valid[None, :], p, 0.0) + l_new = l_i * alpha + tl.sum(p, axis=1) + + acc = acc * alpha[:, None] + tl.dot(p.to(kv.dtype), kv) + m_i = m_new + l_i = l_new + + sink = tl.load(attn_sink_ptr + h_offs, mask=h_mask, other=neg_large).to(tl.float32) + m_final = tl.maximum(m_i, sink) + l_final = l_i * tl.exp(m_i - m_final) + tl.exp(sink - m_final) + + denom = tl.maximum(l_final, 1.0e-30) + out = tl.where(l_final[:, None] > 0.0, acc / denom[:, None], 0.0) + out_base = b * out_stride_b + m * out_stride_m + tl.store( + out_ptr + + out_base + + h_offs[:, None] * out_stride_h + + d_offs[None, :] * out_stride_d, + out, + mask=h_mask[:, None] & d_mask[None, :], + ) + + +def _sparse_attn_triton( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_idxs: torch.Tensor, + softmax_scale: float, +) -> torch.Tensor: + if not q.is_cuda: + raise RuntimeError("Triton sparse_attn requires CUDA/HIP tensors") + if q.dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError(f"Triton sparse_attn expects fp16/bf16 q, got {q.dtype}") + if kv.dtype != q.dtype: + raise RuntimeError( + f"Triton sparse_attn expects kv dtype {q.dtype}, got {kv.dtype}" + ) + + B, M, H, D = q.shape + K = topk_idxs.shape[-1] + out = torch.empty_like(q) + topk_idxs = topk_idxs.to(torch.int32) + + # Process a small head tile per program, matching the TileLang kernel's + # q[h, d] / acc[h, d] structure while keeping V4's D=512 register + # pressure bounded. + block_h = 2 if D >= 256 else 4 + block_d = triton.next_power_of_2(D) + block_k = 16 if D >= 256 else 32 + _sparse_attn_triton_kernel[(B * M, triton.cdiv(H, block_h))]( + q, + kv, + attn_sink, + topk_idxs, + out, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv.stride(0), + kv.stride(1), + kv.stride(2), + topk_idxs.stride(0), + topk_idxs.stride(1), + topk_idxs.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + M, + H, + D, + K, + float(softmax_scale), + BLOCK_H=block_h, + BLOCK_D=block_d, + BLOCK_K=block_k, + num_warps=8, + ) + return out + + +@triton.jit +def _sparse_attn_ragged_triton_kernel( + q_ptr, + kv_ptr, + attn_sink_ptr, + topk_idxs_ptr, + out_ptr, + q_stride_t: tl.constexpr, + q_stride_h: tl.constexpr, + q_stride_d: tl.constexpr, + kv_stride_n: tl.constexpr, + kv_stride_d: tl.constexpr, + topk_stride_t: tl.constexpr, + topk_stride_k: tl.constexpr, + out_stride_t: tl.constexpr, + out_stride_h: tl.constexpr, + out_stride_d: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + K: tl.constexpr, + softmax_scale: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_K: tl.constexpr, +): + t = tl.program_id(0) + pid_h = tl.program_id(1) + + h_offs = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + d_offs = tl.arange(0, BLOCK_D) + h_mask = h_offs < H + d_mask = d_offs < D + + q = tl.load( + q_ptr + + t * q_stride_t + + h_offs[:, None] * q_stride_h + + d_offs[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + ) + + neg_large = -3.4028234663852886e38 + m_i = tl.full((BLOCK_H,), neg_large, dtype=tl.float32) + l_i = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + k_offs = tl.arange(0, BLOCK_K) + for k_start in tl.range(0, K, BLOCK_K): + k_pos = k_start + k_offs + in_range = k_pos < K + idx = tl.load( + topk_idxs_ptr + t * topk_stride_t + k_pos * topk_stride_k, + mask=in_range, + other=-1, + ) + valid = in_range & (idx >= 0) + + kv = tl.load( + kv_ptr + idx[:, None] * kv_stride_n + d_offs[None, :] * kv_stride_d, + mask=valid[:, None] & d_mask[None, :], + other=0.0, + ) + + scores = tl.dot(q, tl.trans(kv)) * softmax_scale + scores = tl.where(h_mask[:, None] & valid[None, :], scores, neg_large) + + m_block = tl.max(scores, axis=1) + m_new = tl.maximum(m_i, m_block) + alpha = tl.exp(m_i - m_new) + p = tl.exp(scores - m_new[:, None]) + p = tl.where(h_mask[:, None] & valid[None, :], p, 0.0) + l_new = l_i * alpha + tl.sum(p, axis=1) + + acc = acc * alpha[:, None] + tl.dot(p.to(kv.dtype), kv) + m_i = m_new + l_i = l_new + + sink = tl.load(attn_sink_ptr + h_offs, mask=h_mask, other=neg_large).to(tl.float32) + m_final = tl.maximum(m_i, sink) + l_final = l_i * tl.exp(m_i - m_final) + tl.exp(sink - m_final) + + denom = tl.maximum(l_final, 1.0e-30) + out = tl.where(l_final[:, None] > 0.0, acc / denom[:, None], 0.0) + tl.store( + out_ptr + + t * out_stride_t + + h_offs[:, None] * out_stride_h + + d_offs[None, :] * out_stride_d, + out, + mask=h_mask[:, None] & d_mask[None, :], + ) + + +def _sparse_attn_ragged_triton( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_idxs: torch.Tensor, + softmax_scale: float, +) -> torch.Tensor: + if not q.is_cuda: + raise RuntimeError("Triton sparse_attn_ragged requires CUDA/HIP tensors") + if q.dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError( + f"Triton sparse_attn_ragged expects fp16/bf16 q, got {q.dtype}" + ) + if kv.dtype != q.dtype: + raise RuntimeError( + f"Triton sparse_attn_ragged expects kv dtype {q.dtype}, got {kv.dtype}" + ) + + T, H, D = q.shape + K = topk_idxs.shape[-1] + out = torch.empty_like(q) + topk_idxs = topk_idxs.to(torch.int32) + + block_h = 2 if D >= 256 else 4 + block_d = triton.next_power_of_2(D) + block_k = 16 if D >= 256 else 32 + _sparse_attn_ragged_triton_kernel[(T, triton.cdiv(H, block_h))]( + q, + kv, + attn_sink, + topk_idxs, + out, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + topk_idxs.stride(0), + topk_idxs.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + H, + D, + K, + float(softmax_scale), + BLOCK_H=block_h, + BLOCK_D=block_d, + BLOCK_K=block_k, + num_warps=8, + ) + return out + + +def _sparse_attn_ragged_torch( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_idxs: torch.Tensor, + softmax_scale: float, +) -> torch.Tensor: + return _sparse_attn_torch( + q.unsqueeze(0), + kv.unsqueeze(0), + attn_sink, + topk_idxs.unsqueeze(0), + softmax_scale, + ).squeeze(0) + + +def sparse_attn_ragged( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_idxs: torch.Tensor, + softmax_scale: float, +) -> torch.Tensor: + """Sparse attention over flat ragged sequences. + + Args: + q: [num_tokens, H, D] + kv: [total_kv, D] + topk_idxs: [num_tokens, K] global indices into `kv`; -1 entries are skipped. + """ + if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": + return _sparse_attn_ragged_triton(q, kv, attn_sink, topk_idxs, softmax_scale) + return _sparse_attn_ragged_torch(q, kv, attn_sink, topk_idxs, softmax_scale) + + +def _sparse_attn_torch( q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, @@ -121,6 +473,23 @@ def sparse_attn( return out.to(out_dtype) +def sparse_attn( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_idxs: torch.Tensor, + softmax_scale: float, +) -> torch.Tensor: + """Sparse multi-head attention with optional Triton backend. + + Set `ATOM_USE_TRITON_ATTN=1` to use the Triton kernel. The torch + implementation remains the default when the env var is unset. + """ + if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": + return _sparse_attn_triton(q, kv, attn_sink, topk_idxs, softmax_scale) + return _sparse_attn_torch(q, kv, attn_sink, topk_idxs, softmax_scale) + + # --------------------------------------------------------------------------- # hc_split_sinkhorn — Manifold-Constrained Hyper-Connections projection # --------------------------------------------------------------------------- diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index a96cddc3cc..b7c38ae680 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -40,7 +40,10 @@ fp4_act_quant_inplace, rotate_activation, ) -from atom.model_ops.sparse_attn_v4 import hc_split_sinkhorn, sparse_attn # noqa: F401 +from atom.model_ops.sparse_attn_v4 import ( # noqa: F401 + hc_split_sinkhorn, + sparse_attn_ragged, +) from atom.model_ops.utils import atom_parameter from atom.model_ops.v4_backend_gate import use_new_v4_backend # noqa: F401 from atom.model_ops.v4_kernels import ( # noqa: F401 @@ -1122,6 +1125,7 @@ def __init__(self, layer_id: int, args: DeepseekV4Args, prefix: str = ""): prefix=f"{p}.wq_a", ) self.q_norm = RMSNorm(self.q_lora_rank, self.eps) + self.q_norm2 = RMSNorm(self.head_dim, self.eps) self.wq_b = ColumnParallelLinear( self.q_lora_rank, self.n_heads * self.head_dim, @@ -1354,7 +1358,8 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: slots = state_slot_mapping[:num_seqs].tolist() per_seq_bts = [block_tables[i] for i in range(num_seqs)] - output_chunks = [] + sparse_kvs = [] + sparse_topks = [] for seq_idx in range(num_seqs): seq_start = seq_offsets[seq_idx] seq_end = seq_offsets[seq_idx + 1] @@ -1365,7 +1370,6 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: block_table = per_seq_bts[seq_idx] seq_positions = positions[seq_start:seq_end] seq_x = x[seq_start:seq_end] - seq_q = q[:, seq_start:seq_end] seq_kv = kv[:, seq_start:seq_end] # PR-A Phase 2: prefer CPU mirror to avoid per-seq GPU→CPU sync # (~64 layers × num_seqs syncs eliminated when present). @@ -1493,16 +1497,37 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: else: kv_sa = self.swa_kv[slot : slot + 1] - o_seq = sparse_attn( - seq_q, kv_sa, self.attn_sink, topk_idxs, self.softmax_scale + sparse_kvs.append(kv_sa.squeeze(0)) + sparse_topks.append(topk_idxs.squeeze(0)) + + # Ragged sparse attention: concatenate per-seq queries and KV pools, + # then rewrite each token's local topk indices into global KV offsets. + q_sa = q.squeeze(0).contiguous() + kv_sa = torch.cat(sparse_kvs, dim=0).contiguous() + max_topk = max(t.size(-1) for t in sparse_topks) + topk_sa = torch.full( + (q_sa.size(0), max_topk), + -1, + device=x.device, + dtype=torch.int32, + ) + kv_offsets = torch.empty(q_sa.size(0), device=x.device, dtype=torch.int32) + token_base = 0 + kv_base = 0 + for seq_kv, seq_topk in zip(sparse_kvs, sparse_topks): + seq_len = seq_topk.size(0) + seq_topk = seq_topk.int() + topk_sa[token_base : token_base + seq_len, : seq_topk.size(-1)] = ( + seq_topk ) - output_chunks.append(o_seq) + kv_offsets[token_base : token_base + seq_len] = kv_base + token_base += seq_len + kv_base += seq_kv.size(0) + topk_sa = torch.where(topk_sa >= 0, topk_sa + kv_offsets[:, None], topk_sa) - o = ( - torch.cat(output_chunks, dim=1) - if len(output_chunks) > 1 - else output_chunks[0] - ) + o = sparse_attn_ragged( + q_sa, kv_sa, self.attn_sink, topk_sa, self.softmax_scale + ).unsqueeze(0) # Inverse RoPE on output's rope dims to remove absolute-position # contribution carried in by the value-side RoPE of the KV entries. From e630b039f02411fa95c40d2c013b006fbd63381b Mon Sep 17 00:00:00 2001 From: chenjun Date: Fri, 1 May 2026 08:17:13 -0500 Subject: [PATCH 2/2] use triton sparse_attn_ragged_varlen --- atom/model_ops/attentions/deepseek_v4_attn.py | 120 +++++++++ atom/model_ops/sparse_attn_v4.py | 238 ++++++++++++++++++ atom/models/deepseek_v4.py | 60 +++-- 3 files changed, 393 insertions(+), 25 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index 190eace489..f942731230 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -44,6 +44,7 @@ AttentionMetadataBuilder, CommonAttentionBuilder, ) +from atom.utils import CpuGpuBuffer from atom.utils.forward_context import AttentionMetaData @@ -96,6 +97,7 @@ def __init__(self, model_runner): self.head_dim = getattr(hf, "kv_head_dim", 512) self.index_head_dim = getattr(hf, "index_head_dim", 128) self.window_size = getattr(hf, "sliding_window", 128) + self.index_topk = getattr(hf, "index_topk", 1024) # Compressor state shape: [coff * ratio, coff * head_dim], fp32. # CSA: ratio=4, overlap=True -> coff=2 -> [8, 2*head_dim] @@ -114,6 +116,24 @@ def __init__(self, model_runner): self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype self._classical_dtype = torch.bfloat16 # compressed KV is BF16 + # Sparse-attn layout metadata, aligned with aiter_mla's CpuGpuBuffer + # convention. Values are per-token and reusable across layers with the + # same compress_ratio; the actual topk indices remain layer-specific. + i32_kwargs = {"dtype": torch.int32, "device": self.device} + i64_kwargs = {"dtype": torch.int64, "device": self.device} + v4_sparse_metadata = {} + for kind in ("dense", "csa", "hca"): + v4_sparse_metadata[f"v4_{kind}_sparse_topk_starts"] = CpuGpuBuffer( + self.max_num_batched_tokens, **i64_kwargs + ) + v4_sparse_metadata[f"v4_{kind}_sparse_topk_lens"] = CpuGpuBuffer( + self.max_num_batched_tokens, **i32_kwargs + ) + v4_sparse_metadata[f"v4_{kind}_sparse_kv_offsets"] = CpuGpuBuffer( + self.max_num_batched_tokens, **i32_kwargs + ) + self.model_runner.forward_vars.update(v4_sparse_metadata) + # ------------------------------------------------------------------ # # AttentionMetadataBuilder hooks (per-request cache abstraction). # # ------------------------------------------------------------------ # @@ -314,6 +334,92 @@ def build_kv_cache_tensor(self, layer_id: int, module): # the rest of ATOM and to support PR3-main multi-sequence wiring). # # ------------------------------------------------------------------ # + def _attach_sparse_layout_metadata( + self, + attn_metadata: AttentionMetaData, + cu_seqlens_q_np, + start_pos_per_seq, + scheduled_bs: int, + total_tokens: int, + ) -> None: + """Precompute per-token ragged sparse-attn layout for each ratio type. + + The actual topk index values are layer-specific (CSA Indexer depends on + weights), but the per-token topk span and global-KV offset layout only + depends on the request geometry and compress_ratio. + """ + import numpy as np + + var = self.model_runner.forward_vars + layouts = {} + ratio_specs = { + 0: ("dense", 0), + 4: ("csa", 4), + 128: ("hca", 128), + } + for ratio_key, (kind, ratio) in ratio_specs.items(): + starts = var[f"v4_{kind}_sparse_topk_starts"].np + lens = var[f"v4_{kind}_sparse_topk_lens"].np + offsets = var[f"v4_{kind}_sparse_kv_offsets"].np + topk_base = 0 + kv_base = 0 + token_base = 0 + max_topk = 0 + for seq_idx in range(scheduled_bs): + seq_start = int(cu_seqlens_q_np[seq_idx]) + seq_end = int(cu_seqlens_q_np[seq_idx + 1]) + token_num = seq_end - seq_start + if token_num == 0: + continue + start_pos = int(start_pos_per_seq[seq_idx]) + end_pos = start_pos + token_num + window_topk = self.window_size if start_pos > 0 else min( + token_num, self.window_size + ) + compress_topk = 0 + kv_len = token_num if start_pos == 0 else self.window_size + if ratio == 4: + compress_topk = min(self.index_topk, end_pos // ratio) + if start_pos == 0: + kv_len = token_num + (token_num + ratio - 1) // ratio + else: + kv_len = self.window_size + end_pos // ratio + elif ratio == 128: + if start_pos > 0: + compress_topk = (start_pos + 1) // ratio + else: + compress_topk = token_num // ratio + if start_pos == 0: + kv_len = token_num + (token_num + ratio - 1) // ratio + else: + kv_len = self.window_size + end_pos // ratio + topk_len = window_topk + compress_topk + starts[token_base : token_base + token_num] = np.arange( + topk_base, + topk_base + token_num * topk_len, + topk_len, + dtype=np.int64, + ) + lens[token_base : token_base + token_num] = topk_len + offsets[token_base : token_base + token_num] = kv_base + token_base += token_num + topk_base += token_num * topk_len + kv_base += kv_len + max_topk = max(max_topk, topk_len) + + topk_starts = var[f"v4_{kind}_sparse_topk_starts"].copy_to_gpu( + total_tokens + ) + topk_lens = var[f"v4_{kind}_sparse_topk_lens"].copy_to_gpu(total_tokens) + kv_offsets = var[f"v4_{kind}_sparse_kv_offsets"].copy_to_gpu(total_tokens) + layouts[ratio_key] = { + "topk_starts": topk_starts, + "topk_lens": topk_lens, + "kv_offsets": kv_offsets, + "max_topk": max_topk, + } + attn_metadata.v4_sparse_layouts = layouts + def prepare_decode(self, batch: ScheduledBatch, bs: int): """V4-style decode prep: populates positions, cu_seqlens_q, block_tables, and state_slot_mapping. @@ -386,6 +492,13 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): attn_metadata.start_pos_per_seq_cpu = positions_np[ cu_seqlens_q_np[:scheduled_bs] ] + self._attach_sparse_layout_metadata( + attn_metadata, + cu_seqlens_q_np, + attn_metadata.start_pos_per_seq_cpu, + scheduled_bs, + sum_scheduled_tokens, + ) return attn_metadata, positions def prepare_prefill(self, batch: ScheduledBatch): @@ -422,6 +535,13 @@ def prepare_prefill(self, batch: ScheduledBatch): attn_metadata.start_pos_per_seq_cpu = positions_np[ cu_seqlens_q_np[:scheduled_bs] ] + self._attach_sparse_layout_metadata( + attn_metadata, + cu_seqlens_q_np, + attn_metadata.start_pos_per_seq_cpu, + scheduled_bs, + sum_scheduled_tokens, + ) return attn_metadata, positions def _populate_block_tables( diff --git a/atom/model_ops/sparse_attn_v4.py b/atom/model_ops/sparse_attn_v4.py index 890d44d9dc..842ec73a9a 100644 --- a/atom/model_ops/sparse_attn_v4.py +++ b/atom/model_ops/sparse_attn_v4.py @@ -20,6 +20,19 @@ import triton import triton.language as tl + +def _bucket_topk(max_topk: int) -> int: + """Limit Triton specializations while keeping K_MAX constexpr.""" + if max_topk <= 0: + return 1 + # Buckets cover V4's common window/indexer regimes without generating a new + # Triton specialization for every prefill length. + for bucket in (128, 256, 512, 1024, 2048, 4096): + if max_topk <= bucket: + return bucket + return triton.next_power_of_2(max_topk) + + # --------------------------------------------------------------------------- # sparse_attn — FlashAttention-style sparse MQA with attention sink # --------------------------------------------------------------------------- @@ -339,6 +352,231 @@ def _sparse_attn_ragged_triton( return out +@triton.jit +def _sparse_attn_ragged_varlen_triton_kernel( + q_ptr, + kv_ptr, + attn_sink_ptr, + topk_flat_ptr, + topk_starts_ptr, + topk_lens_ptr, + kv_offsets_ptr, + out_ptr, + q_stride_t: tl.constexpr, + q_stride_h: tl.constexpr, + q_stride_d: tl.constexpr, + kv_stride_n: tl.constexpr, + kv_stride_d: tl.constexpr, + out_stride_t: tl.constexpr, + out_stride_h: tl.constexpr, + out_stride_d: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + K_MAX: tl.constexpr, + softmax_scale: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_K: tl.constexpr, +): + t = tl.program_id(0) + pid_h = tl.program_id(1) + + h_offs = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + d_offs = tl.arange(0, BLOCK_D) + h_mask = h_offs < H + d_mask = d_offs < D + + q = tl.load( + q_ptr + + t * q_stride_t + + h_offs[:, None] * q_stride_h + + d_offs[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + ) + topk_start = tl.load(topk_starts_ptr + t) + topk_len = tl.load(topk_lens_ptr + t) + kv_offset = tl.load(kv_offsets_ptr + t) + + neg_large = -3.4028234663852886e38 + m_i = tl.full((BLOCK_H,), neg_large, dtype=tl.float32) + l_i = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + k_offs = tl.arange(0, BLOCK_K) + for k_start in tl.range(0, K_MAX, BLOCK_K): + k_pos = k_start + k_offs + in_range = k_pos < topk_len + idx_local = tl.load( + topk_flat_ptr + topk_start + k_pos, + mask=in_range, + other=-1, + ) + valid = in_range & (idx_local >= 0) + idx = idx_local + kv_offset + + kv = tl.load( + kv_ptr + idx[:, None] * kv_stride_n + d_offs[None, :] * kv_stride_d, + mask=valid[:, None] & d_mask[None, :], + other=0.0, + ) + + scores = tl.dot(q, tl.trans(kv)) * softmax_scale + scores = tl.where(h_mask[:, None] & valid[None, :], scores, neg_large) + + m_block = tl.max(scores, axis=1) + m_new = tl.maximum(m_i, m_block) + alpha = tl.exp(m_i - m_new) + p = tl.exp(scores - m_new[:, None]) + p = tl.where(h_mask[:, None] & valid[None, :], p, 0.0) + l_new = l_i * alpha + tl.sum(p, axis=1) + + acc = acc * alpha[:, None] + tl.dot(p.to(kv.dtype), kv) + m_i = m_new + l_i = l_new + + sink = tl.load(attn_sink_ptr + h_offs, mask=h_mask, other=neg_large).to(tl.float32) + m_final = tl.maximum(m_i, sink) + l_final = l_i * tl.exp(m_i - m_final) + tl.exp(sink - m_final) + + denom = tl.maximum(l_final, 1.0e-30) + out = tl.where(l_final[:, None] > 0.0, acc / denom[:, None], 0.0) + tl.store( + out_ptr + + t * out_stride_t + + h_offs[:, None] * out_stride_h + + d_offs[None, :] * out_stride_d, + out, + mask=h_mask[:, None] & d_mask[None, :], + ) + + +def _sparse_attn_ragged_varlen_triton( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_flat: torch.Tensor, + topk_starts: torch.Tensor, + topk_lens: torch.Tensor, + kv_offsets: torch.Tensor, + max_topk: int, + softmax_scale: float, +) -> torch.Tensor: + if not q.is_cuda: + raise RuntimeError("Triton sparse_attn_ragged_varlen requires CUDA/HIP tensors") + if q.dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError( + f"Triton sparse_attn_ragged_varlen expects fp16/bf16 q, got {q.dtype}" + ) + if kv.dtype != q.dtype: + raise RuntimeError( + f"Triton sparse_attn_ragged_varlen expects kv dtype {q.dtype}, got {kv.dtype}" + ) + + T, H, D = q.shape + out = torch.empty_like(q) + topk_flat = topk_flat.to(torch.int32).contiguous() + topk_starts = topk_starts.to(torch.int64).contiguous() + topk_lens = topk_lens.to(torch.int32).contiguous() + kv_offsets = kv_offsets.to(torch.int32).contiguous() + + block_h = 2 if D >= 256 else 4 + block_d = triton.next_power_of_2(D) + block_k = 16 if D >= 256 else 32 + k_max = _bucket_topk(int(max_topk)) + _sparse_attn_ragged_varlen_triton_kernel[(T, triton.cdiv(H, block_h))]( + q, + kv, + attn_sink, + topk_flat, + topk_starts, + topk_lens, + kv_offsets, + out, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + H, + D, + k_max, + float(softmax_scale), + BLOCK_H=block_h, + BLOCK_D=block_d, + BLOCK_K=block_k, + num_warps=8, + ) + return out + + +def _sparse_attn_ragged_varlen_torch( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_flat: torch.Tensor, + topk_starts: torch.Tensor, + topk_lens: torch.Tensor, + kv_offsets: torch.Tensor, + max_topk: int, + softmax_scale: float, +) -> torch.Tensor: + topk_idxs = torch.full( + (q.size(0), max_topk), -1, device=q.device, dtype=torch.int32 + ) + for t in range(q.size(0)): + start = int(topk_starts[t].item()) + length = int(topk_lens[t].item()) + offset = int(kv_offsets[t].item()) + local = topk_flat[start : start + length].to(torch.int32) + topk_idxs[t, :length] = torch.where(local >= 0, local + offset, local) + return _sparse_attn_ragged_torch(q, kv, attn_sink, topk_idxs, softmax_scale) + + +def sparse_attn_ragged_varlen( + q: torch.Tensor, + kv: torch.Tensor, + attn_sink: torch.Tensor, + topk_flat: torch.Tensor, + topk_starts: torch.Tensor, + topk_lens: torch.Tensor, + kv_offsets: torch.Tensor, + max_topk: int, + softmax_scale: float, +) -> torch.Tensor: + """Sparse attention over flat ragged sequences with per-token topk spans. + + `topk_flat` stores local per-seq KV indices; `kv_offsets[t]` is added in the + kernel for valid entries to address the concatenated global KV pool. + """ + if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": + return _sparse_attn_ragged_varlen_triton( + q, + kv, + attn_sink, + topk_flat, + topk_starts, + topk_lens, + kv_offsets, + max_topk, + softmax_scale, + ) + return _sparse_attn_ragged_varlen_torch( + q, + kv, + attn_sink, + topk_flat, + topk_starts, + topk_lens, + kv_offsets, + max_topk, + softmax_scale, + ) + + def _sparse_attn_ragged_torch( q: torch.Tensor, kv: torch.Tensor, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index b7c38ae680..26b8329b39 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -42,7 +42,7 @@ ) from atom.model_ops.sparse_attn_v4 import ( # noqa: F401 hc_split_sinkhorn, - sparse_attn_ragged, + sparse_attn_ragged_varlen, ) from atom.model_ops.utils import atom_parameter from atom.model_ops.v4_backend_gate import use_new_v4_backend # noqa: F401 @@ -1498,35 +1498,45 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: kv_sa = self.swa_kv[slot : slot + 1] sparse_kvs.append(kv_sa.squeeze(0)) - sparse_topks.append(topk_idxs.squeeze(0)) + sparse_topks.append(topk_idxs.reshape(-1)) # Ragged sparse attention: concatenate per-seq queries and KV pools, - # then rewrite each token's local topk indices into global KV offsets. + # and pass varlen topk spans so the kernel can add global KV offsets. q_sa = q.squeeze(0).contiguous() kv_sa = torch.cat(sparse_kvs, dim=0).contiguous() - max_topk = max(t.size(-1) for t in sparse_topks) - topk_sa = torch.full( - (q_sa.size(0), max_topk), - -1, - device=x.device, - dtype=torch.int32, + topk_flat = torch.cat(sparse_topks).contiguous() + ctx = get_forward_context() + sparse_layouts = ( + getattr(ctx.attn_metadata, "v4_sparse_layouts", None) + if ctx is not None and ctx.attn_metadata is not None + else None ) - kv_offsets = torch.empty(q_sa.size(0), device=x.device, dtype=torch.int32) - token_base = 0 - kv_base = 0 - for seq_kv, seq_topk in zip(sparse_kvs, sparse_topks): - seq_len = seq_topk.size(0) - seq_topk = seq_topk.int() - topk_sa[token_base : token_base + seq_len, : seq_topk.size(-1)] = ( - seq_topk - ) - kv_offsets[token_base : token_base + seq_len] = kv_base - token_base += seq_len - kv_base += seq_kv.size(0) - topk_sa = torch.where(topk_sa >= 0, topk_sa + kv_offsets[:, None], topk_sa) - - o = sparse_attn_ragged( - q_sa, kv_sa, self.attn_sink, topk_sa, self.softmax_scale + assert self.compress_ratio in (0, 4, 128), ( + f"unexpected V4 compress_ratio={self.compress_ratio}; " + "expected one of 0, 4, 128" + ) + if _v4_is_dummy_run(): + topk_starts = torch.zeros(q_sa.size(0), device=x.device, dtype=torch.int64) + topk_lens = torch.zeros(q_sa.size(0), device=x.device, dtype=torch.int32) + kv_offsets = torch.zeros(q_sa.size(0), device=x.device, dtype=torch.int32) + max_topk = 1 + else: + layout = sparse_layouts[self.compress_ratio] + topk_starts = layout["topk_starts"] + topk_lens = layout["topk_lens"] + kv_offsets = layout["kv_offsets"] + max_topk = layout["max_topk"] + + o = sparse_attn_ragged_varlen( + q_sa, + kv_sa, + self.attn_sink, + topk_flat, + topk_starts, + topk_lens, + kv_offsets, + max_topk, + self.softmax_scale, ).unsqueeze(0) # Inverse RoPE on output's rope dims to remove absolute-position