Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4d99805
Add DSv4 sparse attention and indexer Triton ops
Oseltamivir Apr 30, 2026
c2eafc5
Tile DSv4 sparse MQA output dimension
Oseltamivir Apr 30, 2026
13b6b39
Retile DSv4 sparse MQA sink
Oseltamivir May 1, 2026
2616197
rm tests
Oseltamivir May 1, 2026
7c5fd02
Revert "rm tests"
Oseltamivir May 1, 2026
19d07bd
Address sparse MQA review comments
Oseltamivir May 1, 2026
914786b
Avoid sync in sparse MQA input checks
Oseltamivir May 1, 2026
aeb8946
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 1, 2026
0923d27
Add batched DSv4 indexer coverage
Oseltamivir May 1, 2026
aa0c5b6
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 1, 2026
220bd4d
Format DSv4 sparse indexer files
Oseltamivir May 3, 2026
085989c
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 3, 2026
883ddb7
fix: make topk per row width configurable
Oseltamivir May 3, 2026
969863a
Merge remote-tracking branch 'upstream/main' into dsv4-sparse-indexer-pr
Oseltamivir May 4, 2026
6bb83f6
Fix mhc pre sqrsum row race
Oseltamivir May 5, 2026
e90f679
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 5, 2026
ef16939
Fix small partial-M blockscale GEMM drift
Oseltamivir May 5, 2026
3b4aad6
Fix DSv4 wo_b blockscale GEMM dispatch
Oseltamivir May 5, 2026
47448f5
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 5, 2026
a42ec8f
Route DSv4 wo_b partial GEMM to CKTile
Oseltamivir May 5, 2026
852979e
Use CK for DSv4 partial wo_b
Oseltamivir May 6, 2026
a1019b4
Merge remote-tracking branch 'upstream/main' into dsv4-sparse-indexer-pr
Oseltamivir May 6, 2026
c13e789
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 6, 2026
8beecb2
Merge branch 'main' into dsv4-sparse-indexer-pr
Oseltamivir May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion aiter/ops/gemm_op_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def get_CKGEMM_config(M: int, N: int, K: int, tuned_file=None):
key = (gfx, cu_num, padded_M, N, K) if has_gfx else (cu_num, padded_M, N, K)
config = _CKGEMM_CONFIG_CACHE[tuned_file].get(key, None)
if config is not None:
config = dict(config)
config["_matched_m"] = padded_M
if AITER_LOG_TUNED_CONFIG:
logger.info(
f"shape is M:{M}, N:{N}, K:{K}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned on cu_num = {cu_num} in {tuned_file} , kernel name is {config['kernelName']}!"
Expand Down Expand Up @@ -749,12 +751,32 @@ def gemm_a8w8_blockscale_bpreshuffle(
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[1]
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)

# DSv4-Pro wo_b under TP8 uses local shape [M, 2048] x [7168, 2048].
# The tuned table only has the full M=20480 row. Batched eval/prefill emits
# partial-M fragments (for example M=5544) that route through padded tuned
# dispatch and have shown row-dependent BF16 drift for identical rows.
# Use generic CK for partial-M fragments; keep the tuned full-shape row
# intact for the throughput benchmark.
if dtype == dtypes.bf16 and n == 7168 and k == 2048 and m != 20480:
return gemm_a8w8_blockscale_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y)

config = get_CKGEMM_config(
m, n, k, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE
)
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)

if config is not None:
libtype = config["libtype"]
# The ASM blockscale kernels are tuned for exact or padded M buckets.
# For small DSv4 partial-M projections (for example M=176/352 mapping
# to the 256/512 buckets), the padded ASM path can produce
# row-dependent BF16 ULP drift for identical rows. CK handles MNK
# padding internally and preserves row equivalence, so use it for
# these small partial-M cases.
matched_m = int(config.get("_matched_m", m))
if libtype == "asm" and matched_m != m and matched_m <= 512:
return gemm_a8w8_blockscale_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y)
if libtype == "cktile":
return gemm_a8w8_blockscale_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, Y)
elif libtype == "ck":
Expand Down
2 changes: 2 additions & 0 deletions aiter/ops/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
"pa_prefill": "attention.pa_prefill",
"pod_attention": "attention.pod_attention",
"prefill_attention": "attention.prefill_attention",
"dsv4_indexer": "attention.dsv4_indexer",
"sparse_mqa_sink": "attention.sparse_mqa_sink",
"unified_attention_sparse_mla": "attention.unified_attention_sparse_mla",
"unified_attention": "attention.unified_attention",
# Fusions modules (fusions/)
Expand Down
241 changes: 241 additions & 0 deletions aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import triton
import triton.language as tl


@triton.jit
def _dsv4_indexer_dense_kernel(
    out_ptr,  # [num_tokens, topk]
    positions_ptr,  # [num_tokens]
    out_stride_t: tl.int64,
    out_stride_k: tl.int64,
    n_committed: tl.constexpr,
    offset: tl.int32,
    ratio: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Dense (no top-k) indexer: one program per token writes identity
    # indices shifted by `offset` for every committed, causally visible
    # slot, and -1 for the remaining slots below n_committed.
    # NOTE(review): slots at k >= n_committed are never written — assumes
    # the caller pre-initializes the [num_tokens, topk] buffer; verify.
    tid = tl.program_id(0)
    position = tl.load(positions_ptr + tid).to(tl.int32)
    # presumably `ratio` maps token positions to KV slots — TODO confirm
    limit = (position + 1) // ratio

    ks = tl.arange(0, BLOCK_K)
    committed = ks < n_committed
    visible = committed & (ks < limit)
    row = tl.where(visible, ks + offset, -1)
    tl.store(
        out_ptr + tid * out_stride_t + ks * out_stride_k,
        row,
        mask=committed,
    )


@triton.jit
def _dsv4_indexer_dense_batched_kernel(
    out_ptr,  # [num_tokens, topk]
    positions_ptr,  # [num_tokens]
    seq_ids_ptr,  # [num_tokens]
    kv_lens_ptr,  # [num_seqs]
    out_stride_t: tl.int64,
    out_stride_k: tl.int64,
    n_committed: tl.constexpr,
    offset: tl.int32,
    ratio: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Batched variant of the dense indexer: each token belongs to a
    # sequence (seq_ids), and valid slots are additionally clipped to
    # that sequence's KV length. One program per token.
    tid = tl.program_id(0)
    seq = tl.load(seq_ids_ptr + tid).to(tl.int32)
    seq_kv_len = tl.load(kv_lens_ptr + seq).to(tl.int32)
    position = tl.load(positions_ptr + tid).to(tl.int32)
    # presumably `ratio` maps token positions to KV slots — TODO confirm
    limit = (position + 1) // ratio

    ks = tl.arange(0, BLOCK_K)
    committed = ks < n_committed
    visible = committed & (ks < seq_kv_len) & (ks < limit)
    row = tl.where(visible, ks + offset, -1)
    tl.store(
        out_ptr + tid * out_stride_t + ks * out_stride_k,
        row,
        mask=committed,
    )


@triton.jit
def _dsv4_indexer_score_kernel(
    score_ptr,  # [num_tokens, kv_len], fp32 output scores
    q_ptr,  # [num_tokens, num_heads, head_dim]
    kv_ptr,  # [kv_len, head_dim]
    weights_ptr,  # [num_tokens, num_heads] per-head mixing weights
    positions_ptr,  # [num_tokens] token positions used for causal masking
    q_stride_t: tl.int64,
    q_stride_h: tl.int64,
    q_stride_d: tl.int64,
    kv_stride_t: tl.int64,
    kv_stride_d: tl.int64,
    weights_stride_t: tl.int64,
    weights_stride_h: tl.int64,
    score_stride_t: tl.int64,
    score_stride_k: tl.int64,
    num_heads: tl.constexpr,
    head_dim: tl.constexpr,
    kv_len: tl.constexpr,
    ratio: tl.constexpr,  # presumably position-to-KV-slot ratio — TODO confirm
    BLOCK_T: tl.constexpr,  # KV-position tile width (grid axis 1)
    BLOCK_H: tl.constexpr,  # heads processed per inner iteration
    BLOCK_D: tl.constexpr,  # head_dim tile; positions >= head_dim are masked
):
    # Compute indexer scores for one token against a BLOCK_T tile of KV
    # positions: score[t] = sum_h weights[h] * relu(q[h] . kv[t]).
    # Grid: (num_tokens, ceil(kv_len / BLOCK_T)).
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T)
    offs_d = tl.arange(0, BLOCK_D)
    d_mask = offs_d < head_dim
    acc = tl.zeros((BLOCK_T,), dtype=tl.float32)

    # Accumulate head contributions in BLOCK_H chunks so the q/kv tiles
    # stay bounded regardless of num_heads.
    for h_start in range(0, num_heads, BLOCK_H):
        offs_h = h_start + tl.arange(0, BLOCK_H)
        h_mask = offs_h < num_heads

        # q tile for this token: (BLOCK_H, BLOCK_D).
        q = tl.load(
            q_ptr
            + token_id * q_stride_t
            + offs_h[:, None] * q_stride_h
            + offs_d[None, :] * q_stride_d,
            mask=h_mask[:, None] & d_mask[None, :],
            other=0.0,
            cache_modifier=".cg",
        )
        # kv tile loaded transposed as (BLOCK_D, BLOCK_T) for tl.dot.
        kv = tl.load(
            kv_ptr + offs_t[None, :] * kv_stride_t + offs_d[:, None] * kv_stride_d,
            mask=(offs_t[None, :] < kv_len) & d_mask[:, None],
            other=0.0,
            cache_modifier=".cg",
        )
        dots = tl.dot(q, kv)
        # ReLU before the weighted reduction over heads.
        dots = tl.maximum(dots, 0.0)
        weights = tl.load(
            weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h,
            mask=h_mask,
            other=0.0,
            cache_modifier=".cg",
        ).to(tl.float32)
        acc += tl.sum(dots * weights[:, None], axis=0)

    # Positions past the causal limit get -inf so a downstream top-k
    # never selects them.
    pos = tl.load(positions_ptr + token_id).to(tl.int32)
    causal_limit = (pos + 1) // ratio
    valid = (offs_t < kv_len) & (offs_t < causal_limit)
    acc = tl.where(valid, acc, float("-inf"))
    tl.store(
        score_ptr + token_id * score_stride_t + offs_t * score_stride_k,
        acc,
        mask=offs_t < kv_len,
    )


@triton.jit
def _dsv4_indexer_score_batched_kernel(
    score_ptr,  # [num_tokens, kv_len], fp32 output scores
    q_ptr,  # [num_tokens, num_heads, head_dim]
    kv_ptr,  # [num_seqs, kv_len, head_dim]
    weights_ptr,  # [num_tokens, num_heads] per-head mixing weights
    positions_ptr,  # [num_tokens] token positions used for causal masking
    seq_ids_ptr,  # [num_tokens] sequence index of each token
    kv_lens_ptr,  # [num_seqs] valid KV length per sequence
    q_stride_t: tl.int64,
    q_stride_h: tl.int64,
    q_stride_d: tl.int64,
    kv_stride_b: tl.int64,
    kv_stride_t: tl.int64,
    kv_stride_d: tl.int64,
    weights_stride_t: tl.int64,
    weights_stride_h: tl.int64,
    score_stride_t: tl.int64,
    score_stride_k: tl.int64,
    num_heads: tl.constexpr,
    head_dim: tl.constexpr,
    kv_len_max: tl.constexpr,  # padded KV width of the score buffer
    ratio: tl.constexpr,  # presumably position-to-KV-slot ratio — TODO confirm
    BLOCK_T: tl.constexpr,  # KV-position tile width (grid axis 1)
    BLOCK_H: tl.constexpr,  # heads processed per inner iteration
    BLOCK_D: tl.constexpr,  # head_dim tile; positions >= head_dim are masked
):
    # Batched variant of the score kernel: the KV cache has a batch axis
    # and each token scores against its own sequence's KV, clipped to
    # that sequence's actual length. Grid: (num_tokens, ceil(kv_len_max / BLOCK_T)).
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32)
    kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32)
    offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T)
    offs_d = tl.arange(0, BLOCK_D)
    d_mask = offs_d < head_dim
    acc = tl.zeros((BLOCK_T,), dtype=tl.float32)

    # Accumulate head contributions in BLOCK_H chunks so the q/kv tiles
    # stay bounded regardless of num_heads.
    for h_start in range(0, num_heads, BLOCK_H):
        offs_h = h_start + tl.arange(0, BLOCK_H)
        h_mask = offs_h < num_heads

        # q tile for this token: (BLOCK_H, BLOCK_D).
        q = tl.load(
            q_ptr
            + token_id * q_stride_t
            + offs_h[:, None] * q_stride_h
            + offs_d[None, :] * q_stride_d,
            mask=h_mask[:, None] & d_mask[None, :],
            other=0.0,
            cache_modifier=".cg",
        )
        # kv tile from this token's sequence, transposed as
        # (BLOCK_D, BLOCK_T) for tl.dot.
        kv = tl.load(
            kv_ptr
            + seq_id * kv_stride_b
            + offs_t[None, :] * kv_stride_t
            + offs_d[:, None] * kv_stride_d,
            mask=(offs_t[None, :] < kv_len) & d_mask[:, None],
            other=0.0,
            cache_modifier=".cg",
        )
        dots = tl.dot(q, kv)
        # ReLU before the weighted reduction over heads.
        dots = tl.maximum(dots, 0.0)
        weights = tl.load(
            weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h,
            mask=h_mask,
            other=0.0,
            cache_modifier=".cg",
        ).to(tl.float32)
        acc += tl.sum(dots * weights[:, None], axis=0)

    # Positions past the sequence length or the causal limit get -inf so
    # a downstream top-k never selects them.
    pos = tl.load(positions_ptr + token_id).to(tl.int32)
    causal_limit = (pos + 1) // ratio
    valid = (offs_t < kv_len_max) & (offs_t < kv_len) & (offs_t < causal_limit)
    acc = tl.where(valid, acc, float("-inf"))
    tl.store(
        score_ptr + token_id * score_stride_t + offs_t * score_stride_k,
        acc,
        mask=offs_t < kv_len_max,
    )


@triton.jit
def _dsv4_indexer_finalize_kernel(
    out_ptr,  # [num_tokens, topk], int32
    values_ptr,  # [num_tokens, topk], fp32
    indices_ptr,  # [num_tokens, topk], int64 from aiter topk
    out_stride_t: tl.int64,
    out_stride_k: tl.int64,
    values_stride_t: tl.int64,
    values_stride_k: tl.int64,
    indices_stride_t: tl.int64,
    indices_stride_k: tl.int64,
    offset: tl.int32,
    topk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Convert per-token top-k (value, index) pairs into final int32
    # indices: entries whose score is effectively -inf are mapped to -1,
    # everything else is shifted by `offset`. One program per token.
    tid = tl.program_id(0)
    ks = tl.arange(0, BLOCK_K)
    lane = ks < topk

    vals = tl.load(
        values_ptr + tid * values_stride_t + ks * values_stride_k,
        mask=lane,
        other=float("-inf"),
    )
    idx = tl.load(
        indices_ptr + tid * indices_stride_t + ks * indices_stride_k,
        mask=lane,
        other=-1,
    ).to(tl.int32)

    # -3.0e38 sits just above the most-negative finite fp32, so this
    # cutoff separates real scores from -inf padding.
    picked = tl.where(vals > -3.0e38, idx + offset, -1)
    tl.store(
        out_ptr + tid * out_stride_t + ks * out_stride_k,
        picked,
        mask=lane,
    )
Loading