
Commit bf67fb1

[Example] Optimize sink attention forward via swizzled layout and report benchmark results (#885)
* Enhance attention sink examples with swizzled layout and performance metrics
  - Added `make_swizzled_layout` annotations for shared tensors in the `flashattn` function across the MHA and GQA examples to optimize memory access patterns.
  - Updated benchmark outputs to include speedup calculations comparing the Triton and TileLang implementations.
* Add README for the Attention Sink example with algorithm details and benchmark results
  - Introduced a new README.md for the Attention Sink example, outlining the forward and backward algorithms, including the computation of `dsinks`.
  - Provided benchmark results comparing the optimized implementation against Triton, highlighting the speedup across various configurations.
* Update README.md for the Attention Sink example to include a link to the Triton implementation
* Update examples/attention_sink/README.md
* Update examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
* Fix a typo

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent c861d8a commit bf67fb1

File tree: 4 files changed (+80, -8 lines)

examples/attention_sink/README.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Attention Sink

We compare against an optimized version of the official Triton implementation, available [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).

## Algorithm

### Forward

The only change from vanilla FlashAttention is that `sinks` must be taken into account in the softmax, which requires an extra rescaling step at the epilogue stage.
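
To make that epilogue rescaling concrete, here is a minimal single-head NumPy sketch (illustrative only; the function name and shapes are assumptions and are not taken from the TileLang kernels). The sink contributes to the softmax normalizer but produces no value, so the accumulated output only needs one extra rescale at the end:

```python
import numpy as np

def sink_attention_fwd(q, k, v, sink, scale):
    """Reference forward pass for one head with an attention sink.

    q: (seq_q, dim), k/v: (seq_kv, dim), sink: scalar per-head logit.
    """
    scores = (q @ k.T) * scale                       # (seq_q, seq_kv)
    row_max = np.maximum(scores.max(axis=-1), sink)  # include the sink in the row max
    p = np.exp(scores - row_max[:, None])            # unnormalized probabilities
    acc = p @ v                                      # unnormalized output accumulator
    denom = p.sum(axis=-1) + np.exp(sink - row_max)  # sink joins the normalizer
    return acc / denom[:, None]                      # epilogue rescaling
```

Compared with vanilla FlashAttention, only the normalizer (and hence the final division) changes; the streaming K/V loop itself is untouched.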

### Backward

Based on a detailed mathematical derivation, the backward computation of `dQ`, `dK`, and `dV` is, interestingly, almost identical to that of vanilla FlashAttention, except that the precise meaning of `lse` differs. The only additional quantity is `dsinks`, which is given by:

$$
dsink_h = -\sum_{b}\sum_{q} P_{b, h, q} \Delta_{b, h, q}
$$

where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th block, $h$-th head and $q$-th query (row).
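
As an illustration of this formula (not part of the example kernels), a minimal NumPy sketch of the `dsinks` reduction, assuming `p_sink` (the sink's share of each row's softmax) and `delta` (the usual rowsum(dO * O) from FlashAttention backward) have already been materialized; names and shapes are hypothetical:

```python
import numpy as np

def compute_dsinks(p_sink, delta):
    """dsink_h = -sum_b sum_q P[b, h, q] * Delta[b, h, q].

    p_sink: (batch, heads, seq_q), sink proportion in each row's softmax.
    delta:  (batch, heads, seq_q), rowsum(dO * O) as in FlashAttention backward.
    Returns one gradient value per head, shape (heads,).
    """
    return -(p_sink * delta).sum(axis=(0, 2))
```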

## Benchmark of the forward pass

### Benchmark Environment

- **Hardware**: NVIDIA H800
- **CUDA version**: 12.9
- **Triton version**: 3.4.0

### Results

- dtype=float16
- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
- Full attention is adopted.

| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
|---------|---------|---------------|-----------------|---------|
| 2048    | 64      | 231.55        | **277.07**      | 1.20x   |
| 2048    | 128     | 313.55        | **393.98**      | 1.26x   |
| 4096    | 64      | 272.17        | **337.30**      | 1.24x   |
| 4096    | 128     | 356.35        | **461.54**      | 1.30x   |
| 8192    | 64      | 289.93        | **353.81**      | 1.22x   |
| 8192    | 128     | 392.18        | **482.50**      | 1.23x   |
| 16384   | 64      | 299.52        | **377.44**      | 1.26x   |
| 16384   | 128     | 404.64        | **519.02**      | 1.28x   |
> The backward performance will be further optimized via fine-grained manual FA3-style pipelining in the TileLang kernel.
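
A side note for readers reproducing the table (not part of the README or the diff): the example scripts print TFLOPs as `total_flops / latency * 1e-9` with latency in milliseconds, and the speedup column is simply the ratio of the two latencies (equivalently, of the two TFLOPs figures). The sketch below assumes the standard 4·B·H·Sq·Skv·D forward-attention FLOP count for `total_flops`, which is an assumption rather than something confirmed by the diff:

```python
def tflops_from_latency(batch, heads, seq_q, seq_kv, headdim, latency_ms):
    """Convert a measured latency (ms) into TFLOPs, assuming the standard
    4 * B * H * Sq * Skv * D forward-attention FLOP count (two GEMMs,
    2 FLOPs per multiply-accumulate). Mirrors the scripts'
    total_flops / latency * 1e-9 print."""
    total_flops = 4.0 * batch * heads * seq_q * seq_kv * headdim
    return total_flops / latency_ms * 1e-9

# The speedup column is the ratio of the TFLOPs figures, e.g. for
# SEQ_LEN=2048, headdim=128:
speedup = 393.98 / 313.55  # ~1.26x
```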

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 18 additions & 8 deletions
@@ -6,6 +6,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
 import triton
@@ -152,6 +153,13 @@ def main(
         logsum = T.alloc_fragment([block_M], accum_dtype)
         sinks = T.alloc_fragment([block_M], dtype)

+        T.annotate_layout({
+            Q_shared: make_swizzled_layout(Q_shared),
+            K_shared: make_swizzled_layout(K_shared),
+            V_shared: make_swizzled_layout(V_shared),
+            O_shared: make_swizzled_layout(O_shared),
+        })
+
         T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)
@@ -425,22 +433,24 @@ def main(
        print("Checks for triton failed.❌")

    # Benchmark triton
-   latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
-   print("Triton: {:.2f} ms".format(latency))
-   print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+   latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
+   print("Triton: {:.2f} ms".format(latency_triton))
+   print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))

    # Benchmark tilelang
-   latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
-   print("Tilelang: {:.2f} ms".format(latency))
-   print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+   latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
+   print("Tilelang: {:.2f} ms".format(latency_tilelang))
+   print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
+
+   print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='heads')
-   parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-   parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
+   parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
+   parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--groups', type=int, default=8, help='groups')
    parser.add_argument(

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse

@@ -140,6 +141,13 @@ def main(
         logsum = T.alloc_fragment([block_M], accum_dtype)
         sinks = T.alloc_fragment([block_M], dtype)

+        T.annotate_layout({
+            Q_shared: make_swizzled_layout(Q_shared),
+            K_shared: make_swizzled_layout(K_shared),
+            V_shared: make_swizzled_layout(V_shared),
+            O_shared: make_swizzled_layout(O_shared),
+        })
+
         T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
 import triton
@@ -145,6 +146,13 @@ def main(
         logsum = T.alloc_fragment([block_M], accum_dtype)
         sinks = T.alloc_fragment([block_M], dtype)

+        T.annotate_layout({
+            Q_shared: make_swizzled_layout(Q_shared),
+            K_shared: make_swizzled_layout(K_shared),
+            V_shared: make_swizzled_layout(V_shared),
+            O_shared: make_swizzled_layout(O_shared),
+        })
+
         T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)

0 commit comments
