[WIP] feat: Add n_offset optimization for sliding window attention#36
Closed
zminglei wants to merge 4 commits into sgl-project:sgl-kernel from
Conversation
Adapt the n_offset optimization from the vLLM flash-attention PR Dao-AILab#78 for sgl-attn. This optimization shifts the KV pointers to the start of the sliding window and reduces seqlen_k, so the kernel only iterates over the actual window blocks instead of the full KV sequence length.

Key changes:
- block.h: compute n_offset in get_n_block_min_max and return it as a 3-tuple
- mainloop_fwd_sm90_tma_gmma_ws.hpp: shift the KV TMA pointers by n_offset in load(); adjust seqlen_k in mma(), with inlined mask helpers for const seqlen_k compatibility
- tile_size.h: use kBlockN=160 with IntraWGOverlap=false for SWA (hdim <= 64)
- flash_api.cpp: disable PagedKV TMA for is_local, add a PackGQA override
- epilogue_fwd.hpp: add a kBlockH template parameter (forced to 1 for now)
- flash_fwd_launch_template.h: propagate kBlockH, add PACK_GQA_BLOCK_SWITCH
- static_switch.h: add the PACK_GQA_BLOCK_SWITCH macro

Performance: ~5% TPOT improvement on gpt-oss-120b (hdim=64, SWA=128, TP1, H200)
- Before: 9.84 ms -> After: 9.35 ms (16 concurrent requests, 1K input, 12K output)

Note: kBlockH is forced to 1 (PackGQA_TMA disabled) pending correctness validation of the sgl-attn epilogue with TMA-based Q/O access patterns.
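The pointer-shift idea can be sketched on the host side as follows. The function name mirrors `get_n_block_min_max` from block.h, but the block arithmetic, rounding, and struct fields here are simplified assumptions for illustration, not the kernel's actual code:

```cpp
#include <algorithm>
#include <cassert>

// Sketch: for a sliding window of width `window_left`, the first KV block a
// query tile can attend to is far past block 0 once the sequence is long.
// Instead of looping from block 0 and masking everything out, shift the KV
// base pointers by n_offset elements and shrink the effective seqlen_k, so
// the kernel iterates only over blocks inside the window.
struct BlockMinMaxOffset {
    int n_block_min;  // first KV block index after the shift (0 here)
    int n_block_max;  // one past the last KV block to visit
    int n_offset;     // element offset applied to the KV pointers in load()
};

BlockMinMaxOffset get_n_block_min_max(int m_block, int kBlockM, int kBlockN,
                                      int seqlen_q, int seqlen_k,
                                      int window_left) {
    // Causal upper bound: last key any query row in this tile may attend to.
    int m_idx_max = std::min((m_block + 1) * kBlockM, seqlen_q);
    int n_block_max =
        (std::min(seqlen_k, seqlen_k - seqlen_q + m_idx_max) + kBlockN - 1) / kBlockN;
    // Window lower bound: earliest key any query row in this tile may see.
    int m_idx_min = m_block * kBlockM;
    int n_min = std::max(0, m_idx_min + seqlen_k - seqlen_q - window_left);
    int n_block_min = n_min / kBlockN;
    // Renumber so iteration starts at block 0 of the shifted KV region.
    int n_offset = n_block_min * kBlockN;
    return {0, n_block_max - n_block_min, n_offset};
}
```

With seqlen_q=1, seqlen_k=12000, window_left=128, and kBlockN=160, the unshifted kernel would iterate 75 KV blocks; with the offset it visits a single block starting at element 11840.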
Force-pushed 1466d1b to 04c50fa.
The tile_size.h changes (kBlockN=160, IntraWGOverlap=false for SWA) caused an accuracy regression. With tile_size.h reverted, all n_offset optimization changes are preserved.
Force-pushed 04c50fa to 754571e.
Keep kBlockN=160 for sliding window attention (better performance), but fix IntraWGOverlap to true (it was !is_local = false, which caused the accuracy regression).
Author
close this PR, moving to #37
Summary
Adapt the n_offset sliding window optimization from vLLM for sgl-attn. This optimization shifts the KV gmem pointers to the start of the sliding window and reduces seqlen_k, so the kernel only iterates over the actual window blocks instead of the full KV sequence length.
Changes
Core optimization (block.h, mainloop_fwd_sm90_tma_gmma_ws.hpp)
- Compute n_offset in get_n_block_min_max() to identify where the sliding window starts
- load(): shift the KV TMA pointers by n_offset so reads start from the window start
- mma(): adjust seqlen_k by subtracting n_offset; inline the mask helper functions for compatibility with the const seqlen_k in SeqlenInfoQK
- mma_pv(): handle the 3-tuple return from get_n_block_min_max()

Tile size tuning (tile_size.h)
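Why mma() has to adjust seqlen_k can be illustrated with a small predicate. Once load() has shifted the KV pointers, the kernel's local column index no longer equals the absolute key position, so the local mask must add n_offset back (or, equivalently, run against a seqlen_k reduced by n_offset). The name `in_sliding_window` and its arguments are hypothetical, for illustration only:

```cpp
#include <cassert>

// After the KV pointer shift, col_local = 0 refers to absolute key position
// n_offset. The sliding-window check must therefore translate back to
// absolute coordinates before comparing against the causal diagonal.
inline bool in_sliding_window(int q_idx, int col_local, int n_offset,
                              int seqlen_q, int seqlen_k, int window_left) {
    int k_idx = col_local + n_offset;        // absolute key position
    int diag = q_idx + seqlen_k - seqlen_q;  // causal diagonal for this query
    return k_idx <= diag && k_idx >= diag - window_left;
}
```

For example, with seqlen_q=1, seqlen_k=12000, window_left=128, and n_offset=11840, the single query row attends to local columns 31 through 159 and nothing outside that range.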
API changes (flash_api.cpp)
Template infrastructure (epilogue_fwd.hpp, flash_fwd_launch_template.h, static_switch.h)
Performance
Benchmarked on gpt-oss-120b (hdim=64, SWA=128, TP1, H200): TPOT improved from 9.84 ms to 9.35 ms (~5%) at 16 concurrent requests with 1K input and 12K output.
Known Limitations
- kBlockH is forced to 1, which disables PackGQA_TMA. Enabling kBlockH > 1 requires validating the sgl-attn epilogue with TMA-based Q/O access patterns (this caused correctness issues during testing).
- Applies only to is_local (sliding window) attention layers.

Test Plan