
Commit 41d579e

Merge commit 41d579e, 2 parents: 79b1941 + 91d5ef5

25 files changed (+1265 / -306 lines)

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -38,9 +38,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
     v += (bos * H + i_h) * V
     block_indices += (bos + i_t) * H * S + i_h * S

-    # if USE_BLOCK_COUNTS:
-    #     NS = tl.load(block_counts + (bos + i_t) * H + i_h)
-    # else:
     NS = S

     p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
@@ -452,7 +449,12 @@ def get_configs():


 @tilelang.autotune(configs=get_configs(),)
-@tilelang.jit
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tilelang_sparse_attention(batch,
                               heads,
                               seq_len,
```

examples/deepseek_nsa/example_tilelang_nsa_bwd.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -17,9 +17,12 @@
 import tilelang


-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tilelang_kernel_fwd(
     batch,
     heads,
```

examples/deepseek_nsa/example_tilelang_nsa_fwd.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -9,8 +9,11 @@


 @tilelang.jit(
-    out_idx=[-1], pass_configs={
+    out_idx=[-1],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
 def native_sparse_attention(batch,
                             heads,
```

examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -16,9 +16,12 @@
 from einops import rearrange


-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def native_sparse_attention_varlen(batch,
                                    heads,
                                    c_seq_len,
```

examples/deepseek_v32/README.md

Lines changed: 56 additions & 1 deletion

```diff
@@ -6,6 +6,7 @@ deepseek_v32/
 ├── figures/                       # Figures and diagrams
 ├── inference/                     # Inference implementation folder
 ├── fp8_lighting_indexer.py        # FP8 lighting indexer
+├── sparse_mla_bwd.py              # Sparse MLA backward implementation
 ├── sparse_mla_fwd.py              # Sparse MLA forward implementation
 ├── sparse_mla_fwd_pipelined.py    # Pipelined implementation of sparse MLA forward pass
 ├── topk_selector.py               # Top-k selector implementation
@@ -21,7 +22,7 @@ The architecture diagram above highlights three key components (shown in green)

 1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
 2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
-3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward pass
+3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward and backward passes

 ### Lightning Indexer

```
The third hunk (`@@ -166,3 +167,57 @@`) appends a new **Sparse MLA Backward** section at the end of the README, immediately after the existing paragraph on the manually pipelined forward pass:

Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
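
As a rough CPU-side analogy only (not the kernel's actual barrier mechanism), the producer/consumer hand-off over a small set of buffers can be pictured with Python threads and two queues; the names `producer`, `consumer`, and `buffers` below are illustrative:

```python
import threading
import queue

# Double-buffered hand-off: the producer fills one buffer while the consumer drains the other.
NUM_BUFFERS, NUM_TILES = 2, 8
free_slots = queue.Queue()
ready_slots = queue.Queue()
buffers = [None] * NUM_BUFFERS
for i in range(NUM_BUFFERS):
    free_slots.put(i)

def producer():
    for tile in range(NUM_TILES):
        slot = free_slots.get()          # wait until a buffer is free ("barrier wait")
        buffers[slot] = f"tile-{tile}"   # stands in for the async copy into shared memory
        ready_slots.put(slot)            # signal that the buffer is ready ("barrier arrive")
    ready_slots.put(None)                # sentinel: no more tiles

def consumer():
    while (slot := ready_slots.get()) is not None:
        print("consuming", buffers[slot])  # stands in for the matmul on the ready buffer
        free_slots.put(slot)               # hand the buffer back to the producer

p, c = threading.Thread(target=producer), threading.Thread(target=consumer)
p.start(); c.start(); p.join(); c.join()
```
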
### Sparse MLA Backward

The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.

The backward pass consists of three main stages:

**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):

```python
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
    T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
    T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
    for i, j in T.Parallel(block_ND, block_ND):
        acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
```
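
For intuition, the delta term is just the row-wise inner product of the output and its gradient, Delta = sum_d(O * dO). A minimal dense sketch in plain PyTorch (the `[batch, seq_len, heads, dim]` layout and the function name are assumptions for illustration, not the kernel's API):

```python
import torch

def mla_bwd_preprocess_reference(O: torch.Tensor, dO: torch.Tensor) -> torch.Tensor:
    """Dense reference for the preprocessing stage (layout assumed for illustration).

    O, dO: [batch, seq_len, heads, dim]
    Returns Delta of shape [batch, seq_len, heads], where
    Delta[b, s, h] = sum_d O[b, s, h, d] * dO[b, s, h, d].
    """
    return (O.float() * dO.float()).sum(dim=-1)
```
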
**2. Main Backward Computation**: Computes gradients through sparse attention:

```python
# Sparse MLA backward: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    # Load KV data for selected indices
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i]

    # Recompute attention scores for backward
    T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

    # Apply softmax gradient: dP = P * (dP_raw - Delta)
    for h_i, bi_i in T.Parallel(padded_H, BI):
        acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
```

The key gradient computations are (see the dense sketch below):
- **dQ = dP @ K** (query gradients)
- **dK = dP^T @ Q** (key gradients)
- **dV = P^T @ dO** (value gradients)
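
Written out as a dense (non-sparse) reference in PyTorch, the three formulas above correspond to the standard softmax-attention backward; this is a sketch for intuition only, and the shapes, names, and single-head-group layout are assumptions rather than the kernel's API:

```python
import torch

def attention_bwd_reference(Q, K, V, dO, sm_scale):
    """Dense reference for the gradient formulas above.

    Q: [heads, q_len, dim]    K, V: [heads, kv_len, dim]    dO: [heads, q_len, dim]
    """
    S = torch.einsum('hqd,hkd->hqk', Q, K) * sm_scale   # attention scores
    P = torch.softmax(S, dim=-1)                         # attention probabilities
    O = torch.einsum('hqk,hkd->hqd', P, V)               # forward output (needed for Delta)

    dV = torch.einsum('hqk,hqd->hkd', P, dO)             # dV = P^T @ dO
    dP = torch.einsum('hqd,hkd->hqk', dO, V)             # dP_raw = dO @ V^T
    Delta = (O * dO).sum(dim=-1, keepdim=True)           # row-wise O.dO (the preprocessing stage)
    dS = P * (dP - Delta) * sm_scale                     # softmax gradient, as in the kernel loop
    dQ = torch.einsum('hqk,hkd->hqd', dS, K)             # dQ = dP @ K
    dK = torch.einsum('hqk,hqd->hkd', dS, Q)             # dK = dP^T @ Q
    return dQ, dK, dV
```

Under these assumptions, the result can be cross-checked against `torch.autograd.grad` applied to the dense forward expression.
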
**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation:

```python
# Atomically update dKV at selected indices
for bi_i, d_i in T.Parallel(BI // split_store, D // 4):
    T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4],
                   acc_dkv_shared[bi_i, d_i * 4])
```
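
Accumulation (rather than a plain store) matters here because different query blocks can select the same KV rows, so their dKV contributions must be summed; the `x4` variant updates four consecutive elements per call, matching the `D // 4` / `d_i * 4` indexing above. A dense analogue of this scatter-accumulate step, sketched with an illustrative helper (names and shapes are assumptions, not the kernel's API):

```python
import torch

def scatter_accumulate_dkv(dKV: torch.Tensor, indices: torch.Tensor, dkv_block: torch.Tensor) -> None:
    """Dense analogue of the atomic dKV update for one batch/head slice.

    dKV:       [kv_len, dim]  global gradient buffer
    indices:   [BI]           KV row indices selected by this block (int64)
    dkv_block: [BI, dim]      this block's contribution
    """
    # Sum contributions into the selected rows; the kernel does the same with atomic adds.
    dKV.index_add_(0, indices, dkv_block)
```
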
**Performance**: The sparse MLA backward achieves excellent performance:
- **H800 SXM**: ~100 TFlops
- **H200 SXM**: ~115 TFlops

The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization.
