Optimize triton swa kernel by skipping computation #8860

Merged: ispobock merged 7 commits into main from ke/opt-triton-swa on Aug 6, 2025
Conversation

@ispobock (Collaborator) commented on Aug 6, 2025

Motivation

For the triton sliding window attention kernel, when the mask for a tile is all false, we can skip the qk computation and the kv loading for that tile entirely.
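The skip test can be expressed purely in terms of tile boundaries: for a query tile covering rows [m_start, m_end) and a key tile covering columns [n_start, n_end), a pair (m, n) is unmasked only if n <= m (causal) and m - n <= window_size. A minimal pure-Python sketch of that condition (the function name and tile-boundary arguments are illustrative, not taken from the actual kernel):

```python
def tile_all_masked(m_start, m_end, n_start, n_end, window_size):
    """Return True if every (query m, key n) pair in the tile is masked out.

    Causal mask: n <= m. Sliding window (window_size >= 0): m - n <= window_size.
    The tile can be skipped when no (m, n) pair satisfies both conditions.
    """
    # Causal: even the earliest key in the tile lies after the latest query.
    if n_start > m_end - 1:
        return True
    # Window: even the latest key is further than window_size behind the
    # earliest query (window_size < 0 means no sliding window).
    if window_size >= 0 and (m_start - (n_end - 1)) > window_size:
        return True
    return False
```

For skippable tiles the kernel can then avoid the kv loads and the qk dot product altogether, which is where the speedups in the benchmark below come from.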

Accuracy Test

gpqa 4k:

python3 -m sglang.launch_server --model-path lmsys/gpt-oss-20b-bf16
OPENAI_BASE_URL=http://localhost:30000/v1 OPENAI_API_KEY=dummy python -m simple-evals.simple_evals --eval gpqa --n-repeats 1

All results: 
| model_name                                              |   ('metric', 'gpqa') |
|:--------------------------------------------------------|---------------------:|
| o4-mini-with-chat-completion-and-4k-gen_20250806_023602 |             0.429293 |

mmlu 4k:

OPENAI_BASE_URL=http://localhost:30000/v1 OPENAI_API_KEY=dummy python -m simple-evals.simple_evals --model o4-mini-with-chat-completion-and-4k-gen --eval mmlu --examples 1000

All results: 
| model_name                                              |   ('metric', 'mmlu') |
|:--------------------------------------------------------|---------------------:|
| o4-mini-with-chat-completion-and-4k-gen_20250806_025747 |                 0.82 |

unit test: #8853

Benchmark & Profiling

python3 benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py

main branch:

     N_CTX  WINDOW_SIZE       torch     triton
0   1024.0         -1.0   12.259432   1.353153
1   1024.0        127.0   12.144628   1.552124
2   1024.0        256.0   12.163820   1.550772
3   1024.0        512.0   12.089032   1.551102
4   2048.0         -1.0   27.315083   4.956499
5   2048.0        127.0   27.476064   5.712122
6   2048.0        256.0   27.348768   5.706436
7   2048.0        512.0   27.343168   5.709613
8   4096.0         -1.0   83.944992  19.438477
9   4096.0        127.0   86.879005  22.288384
10  4096.0        256.0   86.774178  22.279416
11  4096.0        512.0   86.592064  22.300128
12  8192.0         -1.0  239.196320  74.590080
13  8192.0        127.0  249.880478  85.081985
14  8192.0        256.0  248.444382  85.094467
15  8192.0        512.0  246.717850  85.098846

this PR:

     N_CTX  WINDOW_SIZE       torch     triton
0   1024.0         -1.0   12.727996   1.350829
1   1024.0        127.0   13.316014   1.512761
2   1024.0        256.0   13.376210   1.683118
3   1024.0        512.0   15.210245   1.847432
4   2048.0         -1.0   28.888757   4.949422
5   2048.0        127.0   29.648736   4.529437
6   2048.0        256.0   28.753995   4.913037
7   2048.0        512.0   28.829088   5.590257
8   4096.0         -1.0   84.838463  19.435449
9   4096.0        127.0   88.039200  14.164604
10  4096.0        256.0   87.510719  15.041307
11  4096.0        512.0   87.975906  16.576083
12  8192.0         -1.0  240.588287  74.623230
13  8192.0        127.0  249.122406  46.699312
14  8192.0        256.0  248.449081  48.813841
15  8192.0        512.0  247.388519  52.387520
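For reference, the torch column in these tables is a plain masked-attention baseline; a naive pure-Python sketch of sliding-window attention with a causal mask conveys what is being computed (illustrative only, assuming single-head inputs as lists of float vectors — the benchmarked implementations operate on batched CUDA tensors):

```python
import math

def swa_reference(q, k, v, window_size):
    """Naive causal sliding-window attention for one head.

    Key j contributes to query i iff j <= i and i - j <= window_size
    (window_size < 0 disables the window, i.e. full causal attention).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        scores, idxs = [], []
        for j, kj in enumerate(k):
            if j > i:
                continue  # causal mask
            if window_size >= 0 and i - j > window_size:
                continue  # outside the sliding window
            scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
            idxs.append(j)
        m = max(scores)  # subtract max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([
            sum(e * v[j][t] for e, j in zip(exps, idxs)) / z
            for t in range(d)
        ])
    return out
```

With window_size=0 each query attends only to its own position, so the output equals v; with window_size=-1 this reduces to ordinary causal attention.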

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @ispobock, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial performance enhancement to the Triton-based sliding window attention (SWA) kernel. By intelligently bypassing computationally intensive operations for masked-out attention tiles, the changes aim to improve the efficiency and speed of attention computations, particularly in scenarios where sparse attention patterns are prevalent. The PR also includes a new benchmarking utility to validate these performance gains.

Highlights

  • Performance Optimization: Implemented a significant optimization in the Triton sliding window attention (SWA) kernel (_fwd_kernel) to conditionally skip qk (query-key) computation and kv (key-value) loading. This occurs when the attention mask for an entire tile indicates that all values would be invalid (masked out), preventing unnecessary calculations.
  • New Benchmark Script: Added a dedicated benchmark script (bench_triton_swa_kernel.py) to thoroughly evaluate the performance of the optimized Triton SWA kernel against a PyTorch reference implementation. This script facilitates measuring the real-world impact of the introduced optimizations across various sequence lengths and window sizes.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The code changes introduce an optimization in the triton swa kernel that skips computation when the mask for a tile is all false. The review focuses on whether the optimization is effective and whether the SKIP_TILE variable is used correctly to avoid unnecessary computations.

@ispobock ispobock merged commit 0475448 into main Aug 6, 2025
4 of 56 checks passed
@ispobock ispobock deleted the ke/opt-triton-swa branch August 6, 2025 13:37
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025