From 025d17287f0914e07c659bf62084275cf503b535 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Fri, 26 Sep 2025 17:18:20 +0000
Subject: [PATCH 1/6] Enhance attention sink examples with swizzled layout and
 performance metrics

- Added `make_swizzled_layout` annotations for shared tensors in the
  `flashattn` function across MHA and GQA examples to optimize memory access
  patterns.
- Updated benchmark outputs to include speedup calculations comparing Triton
  and TileLang implementations.
---
 ...ample_gqa_sink_fwd_bhsd_wgmma_pipelined.py | 20 ++++++++++++++-----
 .../example_mha_sink_fwd_bhsd.py              |  8 ++++++++
 ...ample_mha_sink_fwd_bhsd_wgmma_pipelined.py |  8 ++++++++
 3 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
index 7df0f32ef..79ed4ea88 100644
--- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
@@ -6,6 +6,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
 import triton
@@ -152,6 +153,13 @@ def main(
             logsum = T.alloc_fragment([block_M], accum_dtype)
             sinks = T.alloc_fragment([block_M], dtype)
 
+            T.annotate_layout({
+                Q_shared: make_swizzled_layout(Q_shared),
+                K_shared: make_swizzled_layout(K_shared),
+                V_shared: make_swizzled_layout(V_shared),
+                O_shared: make_swizzled_layout(O_shared),
+            })
+
             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
@@ -425,22 +433,24 @@ def main(
         print("Checks for triton failed.❌")
 
     # Benchmark triton
-    latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
-    print("Triton: {:.2f} ms".format(latency))
-    print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
+    print("Triton: {:.2f} ms".format(latency_triton))
+    print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
 
     # Benchmark tilelang
     latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
     print("Tilelang: {:.2f} ms".format(latency))
     print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
 
+    print("Speedup: {:.2f}x".format(latency_triton / latency))
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch', type=int, default=1, help='batch size')
     parser.add_argument('--heads', type=int, default=64, help='heads')
-    parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
-    parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
+    parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
+    parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
     parser.add_argument('--dim', type=int, default=128, help='dim')
     parser.add_argument('--groups', type=int, default=8, help='groups')
     parser.add_argument(
diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py
index 45619782f..91af5fec1 100644
--- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py
+++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -5,6 +5,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
 
@@ -140,6 +141,13 @@ def main(
             logsum = T.alloc_fragment([block_M], accum_dtype)
             sinks = T.alloc_fragment([block_M], dtype)
 
+            T.annotate_layout({
+                Q_shared: make_swizzled_layout(Q_shared),
+                K_shared: make_swizzled_layout(K_shared),
+                V_shared: make_swizzled_layout(V_shared),
+                O_shared: make_swizzled_layout(O_shared),
+            })
+
             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
index 7de47fe9e..63801bcb6 100644
--- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
@@ -6,6 +6,7 @@
 from tilelang.autotuner import autotune
 from tilelang.profiler import do_bench
 import tilelang.language as T
+from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
 import triton
@@ -145,6 +146,13 @@ def main(
             logsum = T.alloc_fragment([block_M], accum_dtype)
             sinks = T.alloc_fragment([block_M], dtype)
 
+            T.annotate_layout({
+                Q_shared: make_swizzled_layout(Q_shared),
+                K_shared: make_swizzled_layout(K_shared),
+                V_shared: make_swizzled_layout(V_shared),
+                O_shared: make_swizzled_layout(O_shared),
+            })
+
             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)

From 7f90f4d5a657bd5f6b2e582c1447eb5a96ef0034 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Fri, 26 Sep 2025 17:21:14 +0000
Subject: [PATCH 2/6] Add README for Attention Sink example with algorithm
 details and benchmark results

- Introduced a new README.md file for the Attention Sink example, outlining
  the forward and backward algorithms, including the computation of `dsinks`.
- Provided benchmark results comparing performance metrics of the optimized
  implementation against Triton, highlighting speedup across various
  configurations.
---
 examples/attention_sink/README.md | 46 +++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 examples/attention_sink/README.md

diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md
new file mode 100644
index 000000000..368b381f5
--- /dev/null
+++ b/examples/attention_sink/README.md
@@ -0,0 +1,46 @@
+# Attention Sink
+
+We compare with an optimized version of the official Triton implementation at
+
+
+## Algorithm
+### Forward
+The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage.
+
+### Backward
+Based on detailed mathematical derivation, interestingly, the backward computation process of `dQ`, `dK`, `dV` is almost identical to that in vanilla FlashAttention, except that the specific meaning of `lse` differs. We only need to compute `dsinks`, which is given by:
+
+$$
+dsink_h=-\sum_{b}\sum_{q}P_{b, h, q}\Delta_{b, h, q}
+$$
+
+where $P_{b, h, q}$ is the propotion of $sink_h$ in the softmax in the $b$-th block, $h$-th head and $q$-th query (row).
+
+## Benchmark of forward process
+
+### Benchmark Environment
+- **Hardware**: NVIDIA H800
+- **CUDA version**: 12.9
+- **Triton Version**: 3.4.0
+
+### Results
+
+- dtype=float16
+- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
+- Full attention is adopted.
+
+| SEQ_LEN | headdim | Triton TFLOPS | TileLang TFLOPS | Speedup |
+|---------|---------|---------------|----------------------|---------|
+| 2048 | 64 | 231.55 | **277.07** | 1.20x |
+| 2048 | 128 | 313.55 | **393.98** | 1.26x |
+| | | | | |
+| 4096 | 64 | 272.17 | **337.30** | 1.24x |
+| 4096 | 128 | 356.35 | **461.54** | 1.30x |
+| | | | | |
+| 8192 | 64 | 289.93 | **353.81** | 1.22x |
+| 8192 | 128 | 392.18 | **482.50** | 1.23x |
+| | | | | |
+| 16384 | 64 | 299.52 | **377.44** | 1.26x |
+| 16384 | 128 | 404.64 | **519.02** | 1.28x |
+
+> The backward performance will be further optimized via fine-grained manual pipelining (as in FA3) in the tilelang kernel.
\ No newline at end of file

From 3d8cb7ebc624cc8f44375c0a76d3ee80654a375f Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Fri, 26 Sep 2025 17:24:41 +0000
Subject: [PATCH 3/6] Update README.md for Attention Sink example to include
 link to Triton implementation

---
 examples/attention_sink/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md
index 368b381f5..efaac991c 100644
--- a/examples/attention_sink/README.md
+++ b/examples/attention_sink/README.md
@@ -1,6 +1,6 @@
 # Attention Sink
 
-We compare with an optimized version of the official Triton implementation at
+We compare with an optimized version of the official Triton implementation at [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
 
 
 ## Algorithm

From 7e75150e9082d969b1e47e3e08e5298e547496b9 Mon Sep 17 00:00:00 2001
From: Tong WU <109033598+Rachmanino@users.noreply.github.com>
Date: Sat, 27 Sep 2025 01:25:25 +0800
Subject: [PATCH 4/6] Update examples/attention_sink/README.md

Co-authored-by: gemini-code-assist[bot]
 <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 examples/attention_sink/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md
index efaac991c..fcaebd726 100644
--- a/examples/attention_sink/README.md
+++ b/examples/attention_sink/README.md
@@ -14,7 +14,7 @@ $$
 dsink_h=-\sum_{b}\sum_{q}P_{b, h, q}\Delta_{b, h, q}
 $$
 
-where $P_{b, h, q}$ is the propotion of $sink_h$ in the softmax in the $b$-th block, $h$-th head and $q$-th query (row).
+where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th block, $h$-th head and $q$-th query (row).
 
 ## Benchmark of forward process
 

From 5ddcd7f63f7b374c5779c02c61bdc4a7955587f0 Mon Sep 17 00:00:00 2001
From: Tong WU <109033598+Rachmanino@users.noreply.github.com>
Date: Sat, 27 Sep 2025 01:25:40 +0800
Subject: [PATCH 5/6] Update
 examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Co-authored-by: gemini-code-assist[bot]
 <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../example_gqa_sink_fwd_bhsd_wgmma_pipelined.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
index 79ed4ea88..a54da604f 100644
--- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
@@ -438,11 +438,11 @@ def main(
     print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
 
     # Benchmark tilelang
-    latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
-    print("Tilelang: {:.2f} ms".format(latency))
-    print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
+    print("Tilelang: {:.2f} ms".format(latency_tilelang))
+    print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
 
-    print("Speedup: {:.2f}x".format(latency_triton / latency))
+    print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
 
 
 if __name__ == "__main__":

From 08c0bf88285e62aca2b8405c857f09efe10b3107 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Fri, 26 Sep 2025 17:33:11 +0000
Subject: [PATCH 6/6] typo

---
 examples/attention_sink/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md
index fcaebd726..45d2f926c 100644
--- a/examples/attention_sink/README.md
+++ b/examples/attention_sink/README.md
@@ -8,7 +8,7 @@ We compare with an optimized version of the official Triton implementation at [h
 The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage.
 
 ### Backward
-Based on detailed mathematical derivation, interestingly, the backward computation process of `dQ`, `dK`, `dV` is almost identical to that in vanilla FlashAttention, except that the specific meaning of `lse` differs. We only need to compute `dsinks`, which is given by:
+Based on detailed mathematical derivation, interestingly, the backward computation process of `dQ`, `dK`, `dV` is almost identical to that in vanilla FlashAttention, except that the specific meaning of `lse` differs. We only need to compute `dsinks` additionally, which is given by:
 
 $$
 dsink_h=-\sum_{b}\sum_{q}P_{b, h, q}\Delta_{b, h, q}
@@ -29,7 +29,7 @@ where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th b
 - batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
 - Full attention is adopted.
 
-| SEQ_LEN | headdim | Triton TFLOPS | TileLang TFLOPS | Speedup |
+| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
 |---------|---------|---------------|----------------------|---------|
 | 2048 | 64 | 231.55 | **277.07** | 1.20x |
 | 2048 | 128 | 313.55 | **393.98** | 1.26x |
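
As a quick reference for the algorithm described in the README added in PATCH 2/6 above, the following is a minimal, unfused PyTorch sketch of attention with per-head sinks. It illustrates the math only: the function name `attention_with_sinks`, the tensor shapes, and float32 precision are assumptions, GQA grouping and sliding windows are ignored, and this is not the API of the Triton or TileLang kernels.

```python
import torch


def attention_with_sinks(q, k, v, sinks, sm_scale):
    """Dense reference: a per-head sink logit only enlarges the softmax denominator."""
    # q: [B, H, S_q, D]; k, v: [B, H, S_kv, D]; sinks: [H]
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * sm_scale
    row_max = torch.maximum(scores.amax(dim=-1), sinks[None, :, None])  # max over keys and sink
    probs = torch.exp(scores - row_max[..., None])                      # unnormalized weights
    denom = probs.sum(dim=-1) + torch.exp(sinks[None, :, None] - row_max)
    # The sink carries no value vector, so it only affects the final division,
    # i.e. the extra rescaling at the epilogue stage mentioned in the README.
    return torch.einsum("bhqk,bhkd->bhqd", probs, v) / denom[..., None]
```

Under the same assumptions, the `dsinks` formula from the Backward section (with `P` the sink's share of each row's softmax mass and `Delta = rowsum(dO * O)`) can be sanity-checked against autograd:

```python
B, H, S, D = 2, 4, 64, 32
q, k, v, do = (torch.randn(B, H, S, D) for _ in range(4))
sinks = torch.randn(H, requires_grad=True)

o = attention_with_sinks(q, k, v, sinks, D**-0.5)
o.backward(do)

with torch.no_grad():
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * D**-0.5
    row_max = torch.maximum(scores.amax(dim=-1), sinks[None, :, None])
    denom = (torch.exp(scores - row_max[..., None]).sum(dim=-1)
             + torch.exp(sinks[None, :, None] - row_max))
    P = torch.exp(sinks[None, :, None] - row_max) / denom  # sink's proportion per (b, h, q) row
    Delta = (do * o).sum(dim=-1)                           # rowsum(dO * O)
    dsinks = -(P * Delta).sum(dim=(0, 2))                  # reduce over batch and query rows
print(torch.allclose(dsinks, sinks.grad, atol=1e-4))       # expected to print True
```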