Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4d99805
Add DSv4 sparse attention and indexer Triton ops
Oseltamivir Apr 30, 2026
c2eafc5
Tile DSv4 sparse MQA output dimension
Oseltamivir Apr 30, 2026
13b6b39
Retile DSv4 sparse MQA sink
Oseltamivir May 1, 2026
2616197
rm tests
Oseltamivir May 1, 2026
7c5fd02
Revert "rm tests"
Oseltamivir May 1, 2026
19d07bd
Address sparse MQA review comments
Oseltamivir May 1, 2026
914786b
Avoid sync in sparse MQA input checks
Oseltamivir May 1, 2026
aeb8946
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 1, 2026
0923d27
Add batched DSv4 indexer coverage
Oseltamivir May 1, 2026
aa0c5b6
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 1, 2026
220bd4d
Format DSv4 sparse indexer files
Oseltamivir May 3, 2026
085989c
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 3, 2026
883ddb7
fix: make topk per row width configurable
Oseltamivir May 3, 2026
969863a
Merge remote-tracking branch 'upstream/main' into dsv4-sparse-indexer-pr
Oseltamivir May 4, 2026
6bb83f6
Fix mhc pre sqrsum row race
Oseltamivir May 5, 2026
e90f679
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 5, 2026
ef16939
Fix small partial-M blockscale GEMM drift
Oseltamivir May 5, 2026
3b4aad6
Fix DSv4 wo_b blockscale GEMM dispatch
Oseltamivir May 5, 2026
47448f5
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 5, 2026
a42ec8f
Route DSv4 wo_b partial GEMM to CKTile
Oseltamivir May 5, 2026
852979e
Use CK for DSv4 partial wo_b
Oseltamivir May 6, 2026
a1019b4
Merge remote-tracking branch 'upstream/main' into dsv4-sparse-indexer-pr
Oseltamivir May 6, 2026
c13e789
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 6, 2026
8beecb2
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aiter/ops/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
"pa_prefill": "attention.pa_prefill",
"pod_attention": "attention.pod_attention",
"prefill_attention": "attention.prefill_attention",
"dsv4_indexer": "attention.dsv4_indexer",
"sparse_mqa_sink": "attention.sparse_mqa_sink",
"unified_attention_sparse_mla": "attention.unified_attention_sparse_mla",
"unified_attention": "attention.unified_attention",
# Fusions modules (fusions/)
Expand Down
241 changes: 241 additions & 0 deletions aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import triton
import triton.language as tl


@triton.jit
def _dsv4_indexer_dense_kernel(
out_ptr, # [num_tokens, topk]
positions_ptr, # [num_tokens]
out_stride_t: tl.int64,
out_stride_k: tl.int64,
n_committed: tl.constexpr,
offset: tl.int32,
ratio: tl.constexpr,
BLOCK_K: tl.constexpr,
):
token_id = tl.program_id(0)
offs_k = tl.arange(0, BLOCK_K)
pos = tl.load(positions_ptr + token_id).to(tl.int32)
causal_limit = (pos + 1) // ratio
valid = (offs_k < n_committed) & (offs_k < causal_limit)
out = tl.where(valid, offs_k + offset, -1)
tl.store(
out_ptr + token_id * out_stride_t + offs_k * out_stride_k,
out,
mask=offs_k < n_committed,
)


@triton.jit
def _dsv4_indexer_dense_batched_kernel(
out_ptr, # [num_tokens, topk]
positions_ptr, # [num_tokens]
seq_ids_ptr, # [num_tokens]
kv_lens_ptr, # [num_seqs]
out_stride_t: tl.int64,
out_stride_k: tl.int64,
n_committed: tl.constexpr,
offset: tl.int32,
ratio: tl.constexpr,
BLOCK_K: tl.constexpr,
):
token_id = tl.program_id(0)
offs_k = tl.arange(0, BLOCK_K)
seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32)
kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32)
pos = tl.load(positions_ptr + token_id).to(tl.int32)
causal_limit = (pos + 1) // ratio
valid = (offs_k < n_committed) & (offs_k < kv_len) & (offs_k < causal_limit)
out = tl.where(valid, offs_k + offset, -1)
tl.store(
out_ptr + token_id * out_stride_t + offs_k * out_stride_k,
out,
mask=offs_k < n_committed,
)


@triton.jit
def _dsv4_indexer_score_kernel(
score_ptr, # [num_tokens, kv_len], fp32
q_ptr, # [num_tokens, num_heads, head_dim]
kv_ptr, # [kv_len, head_dim]
weights_ptr, # [num_tokens, num_heads]
positions_ptr, # [num_tokens]
q_stride_t: tl.int64,
q_stride_h: tl.int64,
q_stride_d: tl.int64,
kv_stride_t: tl.int64,
kv_stride_d: tl.int64,
weights_stride_t: tl.int64,
weights_stride_h: tl.int64,
score_stride_t: tl.int64,
score_stride_k: tl.int64,
num_heads: tl.constexpr,
head_dim: tl.constexpr,
kv_len: tl.constexpr,
ratio: tl.constexpr,
BLOCK_T: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_D: tl.constexpr,
):
token_id = tl.program_id(0)
tile_id = tl.program_id(1)

offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T)
offs_d = tl.arange(0, BLOCK_D)
d_mask = offs_d < head_dim
acc = tl.zeros((BLOCK_T,), dtype=tl.float32)

for h_start in range(0, num_heads, BLOCK_H):
offs_h = h_start + tl.arange(0, BLOCK_H)
h_mask = offs_h < num_heads

q = tl.load(
q_ptr
+ token_id * q_stride_t
+ offs_h[:, None] * q_stride_h
+ offs_d[None, :] * q_stride_d,
mask=h_mask[:, None] & d_mask[None, :],
other=0.0,
cache_modifier=".cg",
)
kv = tl.load(
kv_ptr + offs_t[None, :] * kv_stride_t + offs_d[:, None] * kv_stride_d,
mask=(offs_t[None, :] < kv_len) & d_mask[:, None],
other=0.0,
cache_modifier=".cg",
)
dots = tl.dot(q, kv)
dots = tl.maximum(dots, 0.0)
weights = tl.load(
weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h,
mask=h_mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
acc += tl.sum(dots * weights[:, None], axis=0)

pos = tl.load(positions_ptr + token_id).to(tl.int32)
causal_limit = (pos + 1) // ratio
valid = (offs_t < kv_len) & (offs_t < causal_limit)
acc = tl.where(valid, acc, float("-inf"))
tl.store(
score_ptr + token_id * score_stride_t + offs_t * score_stride_k,
acc,
mask=offs_t < kv_len,
)


@triton.jit
def _dsv4_indexer_score_batched_kernel(
score_ptr, # [num_tokens, kv_len], fp32
q_ptr, # [num_tokens, num_heads, head_dim]
kv_ptr, # [num_seqs, kv_len, head_dim]
weights_ptr, # [num_tokens, num_heads]
positions_ptr, # [num_tokens]
seq_ids_ptr, # [num_tokens]
kv_lens_ptr, # [num_seqs]
q_stride_t: tl.int64,
q_stride_h: tl.int64,
q_stride_d: tl.int64,
kv_stride_b: tl.int64,
kv_stride_t: tl.int64,
kv_stride_d: tl.int64,
weights_stride_t: tl.int64,
weights_stride_h: tl.int64,
score_stride_t: tl.int64,
score_stride_k: tl.int64,
num_heads: tl.constexpr,
head_dim: tl.constexpr,
kv_len_max: tl.constexpr,
ratio: tl.constexpr,
BLOCK_T: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_D: tl.constexpr,
):
token_id = tl.program_id(0)
tile_id = tl.program_id(1)

seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32)
kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32)
offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T)
offs_d = tl.arange(0, BLOCK_D)
d_mask = offs_d < head_dim
acc = tl.zeros((BLOCK_T,), dtype=tl.float32)

for h_start in range(0, num_heads, BLOCK_H):
offs_h = h_start + tl.arange(0, BLOCK_H)
h_mask = offs_h < num_heads

q = tl.load(
q_ptr
+ token_id * q_stride_t
+ offs_h[:, None] * q_stride_h
+ offs_d[None, :] * q_stride_d,
mask=h_mask[:, None] & d_mask[None, :],
other=0.0,
cache_modifier=".cg",
)
kv = tl.load(
kv_ptr
+ seq_id * kv_stride_b
+ offs_t[None, :] * kv_stride_t
+ offs_d[:, None] * kv_stride_d,
mask=(offs_t[None, :] < kv_len) & d_mask[:, None],
other=0.0,
cache_modifier=".cg",
)
dots = tl.dot(q, kv)
dots = tl.maximum(dots, 0.0)
weights = tl.load(
weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h,
mask=h_mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
acc += tl.sum(dots * weights[:, None], axis=0)

pos = tl.load(positions_ptr + token_id).to(tl.int32)
causal_limit = (pos + 1) // ratio
valid = (offs_t < kv_len_max) & (offs_t < kv_len) & (offs_t < causal_limit)
acc = tl.where(valid, acc, float("-inf"))
tl.store(
score_ptr + token_id * score_stride_t + offs_t * score_stride_k,
acc,
mask=offs_t < kv_len_max,
)


@triton.jit
def _dsv4_indexer_finalize_kernel(
out_ptr, # [num_tokens, topk], int32
values_ptr, # [num_tokens, topk], fp32
indices_ptr, # [num_tokens, topk], int64 from aiter topk
out_stride_t: tl.int64,
out_stride_k: tl.int64,
values_stride_t: tl.int64,
values_stride_k: tl.int64,
indices_stride_t: tl.int64,
indices_stride_k: tl.int64,
offset: tl.int32,
topk: tl.constexpr,
BLOCK_K: tl.constexpr,
):
token_id = tl.program_id(0)
offs_k = tl.arange(0, BLOCK_K)
values = tl.load(
values_ptr + token_id * values_stride_t + offs_k * values_stride_k,
mask=offs_k < topk,
other=float("-inf"),
)
indices = tl.load(
indices_ptr + token_id * indices_stride_t + offs_k * indices_stride_k,
mask=offs_k < topk,
other=-1,
).to(tl.int32)
out = tl.where(values > -3.0e38, indices + offset, -1)
tl.store(
out_ptr + token_id * out_stride_t + offs_k * out_stride_k,
out,
mask=offs_k < topk,
)
Loading