
[WIP] feat: Add n_offset optimization for sliding window attention #36

Closed
zminglei wants to merge 4 commits into sgl-project:sgl-kernel from zminglei:feat/swa-n-offset-optimization

Conversation


@zminglei zminglei commented Feb 27, 2026

Summary

Adapts the n_offset sliding-window optimization from vLLM for sgl-attn. The optimization shifts the KV gmem pointers to the start of the sliding window and reduces seqlen_k accordingly, so the kernel iterates only over the blocks actually inside the window instead of the full KV sequence length.

Changes

Core optimization (block.h, mainloop_fwd_sm90_tma_gmma_ws.hpp)

  • Compute n_offset in get_n_block_min_max() to identify where the sliding window starts
  • Return n_offset as part of 3-tuple from get_n_block_min_max()
  • In load(): shift KV TMA pointers by n_offset so we start reading from the window start
  • In mma(): adjust seqlen_k by subtracting n_offset, inline mask helper functions for compatibility with const seqlen_k in SeqlenInfoQK
  • In mma_pv(): handle 3-tuple return from get_n_block_min_max()
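
The core idea behind the 3-tuple can be sketched as follows. This is a minimal, hypothetical reconstruction (names and the exact index arithmetic are assumptions; the real logic lives in block.h and handles causal/local corner cases the sketch omits):

```cpp
#include <algorithm>
#include <cassert>
#include <tuple>

// Hypothetical sketch of the n_offset idea: for a causal sliding-window
// query block starting at row m_start, the earliest KV position it may
// attend to is (seqlen_k - seqlen_q + m_start - window_size_left).
// Rounding that down to a KV block boundary gives n_offset, the number of
// KV elements the kernel can skip entirely by shifting its gmem pointers.
std::tuple<int, int, int> get_n_block_min_max_sketch(
    int seqlen_q, int seqlen_k, int m_block, int kBlockM, int kBlockN,
    int window_size_left) {
  int m_start = m_block * kBlockM;
  // Highest KV index this query block may attend to (causal masking).
  int n_max_idx = std::min(seqlen_k, seqlen_k - seqlen_q + m_start + kBlockM);
  int n_block_max = (n_max_idx + kBlockN - 1) / kBlockN;
  // Lowest KV index still inside the sliding window.
  int n_min_idx =
      std::max(0, seqlen_k - seqlen_q + m_start - window_size_left);
  int n_block_min = n_min_idx / kBlockN;
  // load() shifts the KV pointers by n_offset elements; mma() then works
  // on seqlen_k - n_offset and iterates (n_block_max - n_block_min) blocks.
  int n_offset = n_block_min * kBlockN;
  return {n_block_min, n_block_max, n_offset};
}
```

For example, with seqlen_k = 4096, seqlen_q = 128, kBlockM = kBlockN = 128 and a 128-token window, the first query block only needs KV blocks 30..31 and can skip 3840 KV elements, rather than scanning all 32 blocks.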

Tile size tuning (tile_size.h)

  • Use kBlockN=160 with IntraWGOverlap=false for SWA with hdim<=64

API changes (flash_api.cpp)

  • Disable PagedKV TMA for is_local (required for n_offset pointer shifting)
  • Add PackGQA override for hdim=64, h/h_k=8, is_local
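
A hedged sketch of what these two flash_api.cpp heuristics might look like (function names and the direction of the PackGQA override are assumptions; the PR does not spell them out):

```cpp
#include <cassert>

// Hypothetical sketch: TMA descriptors encode a fixed base address, so
// once load() shifts the KV pointers by n_offset, paged-KV TMA cannot be
// combined with local (sliding-window) attention and is disabled for it.
bool use_paged_kv_tma(bool has_page_table, bool is_local) {
  return has_page_table && !is_local;
}

// Hypothetical sketch of the PackGQA override for hdim=64, h/h_k=8,
// is_local; forcing PackGQA on in that case is an assumption here.
bool use_pack_gqa(bool pack_gqa_default, int hdim, int h, int h_k,
                  bool is_local) {
  if (is_local && hdim == 64 && h_k != 0 && h / h_k == 8) return true;
  return pack_gqa_default;
}
```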

Template infrastructure (epilogue_fwd.hpp, flash_fwd_launch_template.h, static_switch.h)

  • Add kBlockH template parameter (infrastructure for future PackGQA_TMA support)
  • Add PACK_GQA_BLOCK_SWITCH macro
  • kBlockH forced to 1 for now (see Known Limitations)
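
The new switch presumably follows the pattern of the existing BOOL_SWITCH-style macros in static_switch.h: an immediately invoked lambda that turns a runtime flag into compile-time constants. A hypothetical sketch (macro and parameter names assumed), with kBlockH pinned to 1 in both branches as described above:

```cpp
#include <cassert>

// Hypothetical PACK_GQA_BLOCK_SWITCH in the style of static_switch.h:
// dispatches on whether PackGQA is enabled and exposes both a constexpr
// bool and a constexpr int kBlockH to the callee. kBlockH is forced to 1
// for now (PackGQA_TMA disabled).
#define PACK_GQA_BLOCK_SWITCH(PACK_GQA, PACK_NAME, BLOCKH_NAME, ...) \
  [&] {                                                              \
    if (PACK_GQA) {                                                  \
      constexpr static bool PACK_NAME = true;                        \
      constexpr static int BLOCKH_NAME = 1;                          \
      return __VA_ARGS__();                                          \
    } else {                                                         \
      constexpr static bool PACK_NAME = false;                       \
      constexpr static int BLOCKH_NAME = 1;                          \
      return __VA_ARGS__();                                          \
    }                                                                \
  }()

// Toy kernel selector showing how the constants reach a template.
template <bool PackGQA, int kBlockH>
int kernel_variant() { return PackGQA ? 10 + kBlockH : kBlockH; }

int dispatch(bool pack_gqa) {
  return PACK_GQA_BLOCK_SWITCH(pack_gqa, PackGQA_, kBlockH_, [&] {
    return kernel_variant<PackGQA_, kBlockH_>();
  });
}
```

Once the epilogue is validated with TMA-based Q/O access, the PackGQA branch could set kBlockH to the GQA ratio instead of 1.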

Performance

Benchmarked on gpt-oss-120b (hdim=64, SWA=128, TP1, H200):

  • Before: 9.84ms TPOT
  • After: 9.35ms TPOT (~5% improvement)
  • Benchmark: 16 concurrent requests, 1K input tokens, 12K output tokens
  • Correctness verified: GSM8K accuracy 85.5% (matches baseline)

Known Limitations

  • kBlockH is forced to 1, which disables PackGQA_TMA. Enabling kBlockH > 1 requires validating sgl-attn epilogue with TMA-based Q/O access patterns (this caused correctness issues during testing).
  • The n_offset optimization only activates for is_local (sliding window) attention layers.

Test Plan

  • Build succeeds (ninja -j32, 372 targets)
  • GSM8K accuracy: 85.5% (verified correct output)
  • TPOT benchmark: 9.35ms (5% improvement over 9.84ms baseline)
  • vLLM same-node baseline: 9.12ms (remaining gap ~2.5%)

Adapt the n_offset optimization from vLLM flash-attention PR Dao-AILab#78 for sgl-attn.
This optimization shifts KV pointers to the start of the sliding window and reduces
seqlen_k, so the kernel only iterates over actual window blocks instead of the full
KV sequence length.

Key changes:
- block.h: Compute n_offset in get_n_block_min_max, return as 3-tuple
- mainloop_fwd_sm90_tma_gmma_ws.hpp: Shift KV TMA pointers by n_offset in load(),
  adjust seqlen_k in mma() with inlined mask helpers for const seqlen_k compatibility
- tile_size.h: Use kBlockN=160 with IntraWGOverlap=false for SWA (hdim<=64)
- flash_api.cpp: Disable PagedKV TMA for is_local, add PackGQA override
- epilogue_fwd.hpp: Add kBlockH template parameter (forced to 1 for now)
- flash_fwd_launch_template.h: Propagate kBlockH, add PACK_GQA_BLOCK_SWITCH
- static_switch.h: Add PACK_GQA_BLOCK_SWITCH macro

Performance: ~5% TPOT improvement on gpt-oss-120b (hdim=64, SWA=128, TP1, H200)
  - Before: 9.84ms -> After: 9.35ms (16 concurrent, 1K input, 12K output)

Note: kBlockH is forced to 1 (PackGQA_TMA disabled) pending correctness validation
of sgl-attn epilogue with TMA-based Q/O access patterns.
@zminglei zminglei changed the title feat: Add n_offset optimization for sliding window attention [WIP] feat: Add n_offset optimization for sliding window attention Feb 28, 2026
@zminglei zminglei force-pushed the feat/swa-n-offset-optimization branch from 1466d1b to 04c50fa Compare February 28, 2026 18:23
The tile_size.h changes (kBlockN=160, IntraWGOverlap=false for SWA)
caused accuracy regression. With tile_size.h reverted, all n_offset
optimization changes are preserved.
@zminglei zminglei force-pushed the feat/swa-n-offset-optimization branch from 04c50fa to 754571e Compare February 28, 2026 18:25
Keep kBlockN=160 for sliding window attention (better perf), but fix
IntraWGOverlap to true (it had been set to !is_local = false, which
caused the accuracy regression).
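
The tile-size heuristic the two fixup commits converge on can be sketched as follows (struct and function names are hypothetical; the non-SWA default of 128 is an assumption):

```cpp
#include <cassert>

// Hypothetical sketch of the tile_size.h heuristic after the fixes:
// local (sliding-window) attention with hdim <= 64 uses kBlockN = 160,
// while IntraWGOverlap stays true everywhere (tying it to !is_local was
// what caused the accuracy regression).
struct TileConfig {
  int kBlockN;
  bool intra_wg_overlap;
};

TileConfig select_tile_config(int hdim, bool is_local) {
  if (is_local && hdim <= 64) return {160, true};
  return {128, true};  // assumed default for the non-SWA paths
}
```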

zminglei commented Mar 2, 2026

close this PR, moving to #37

@zminglei zminglei closed this Mar 2, 2026
