Add compress_factor for compressed causal attention #2418
Add a `compress_factor: int = 1` parameter to `_flash_attn_fwd` and `_flash_attn_bwd` that adjusts the causal mask relationship from `kv_idx <= q_idx` to `kv_idx <= q_idx // compress_factor`. This enables native causal masking for compressed KV sequences where each KV token represents `compress_factor` original tokens (e.g., in sparse attention). Without this, users must supply a custom `mask_mod` function, which is slower because it prevents blocks from being classified as "full": all blocks become "mask" blocks requiring per-element evaluation.

Currently SM100/SM110 (Blackwell) only. When `compress_factor=1` (default), all code paths are identical to existing behavior.

Changes:
- interface.py: add compress_factor param to _flash_attn_fwd/_flash_attn_bwd, include in SM100 compile keys, pass to SM100 kernel constructors
- flash_fwd_sm100.py: accept compress_factor, pass to BlockInfo/AttentionMask
- flash_bwd_sm100.py: same as forward
- block_info.py: add compress_factor field; in get_n_block_min_max compute n_idx = m_idx_max // compress_factor + 1; in get_m_block_min_max compute m_idx = n_idx_min * compress_factor
- mask.py: add compress_factor field; in apply_mask/apply_mask_sm100 divide row_idx by compress_factor and remove seqlen_k-seqlen_q offset; in apply_mask_sm100_transposed scale causal_offset by compress_factor
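For readers who want to check the mask relationship by hand, here is a minimal pure-Python reference of the predicate described above. It is only an illustrative sketch: the function names `compressed_causal_allowed` and `compressed_causal_mask` are hypothetical and do not exist in this PR or in the library, and it assumes no `seqlen_k - seqlen_q` offset, as stated in the mask.py change.

```python
def compressed_causal_allowed(q_idx: int, kv_idx: int, compress_factor: int = 1) -> bool:
    """Return True if query position q_idx may attend to compressed KV index kv_idx.

    Hypothetical reference helper, not part of the PR.
    """
    if compress_factor < 1:
        raise ValueError("compress_factor must be >= 1")
    # Each KV entry stands for `compress_factor` original tokens, so the
    # causal bound on the KV side is the query index divided down.
    return kv_idx <= q_idx // compress_factor


def compressed_causal_mask(seqlen_q: int, seqlen_k: int, compress_factor: int = 1) -> list[list[bool]]:
    """Dense boolean mask of shape [seqlen_q][seqlen_k] (True = attend)."""
    return [
        [compressed_causal_allowed(q, kv, compress_factor) for kv in range(seqlen_k)]
        for q in range(seqlen_q)
    ]


if __name__ == "__main__":
    # 8 query positions, 2 compressed KV entries, 4 original tokens per entry.
    for q, row in enumerate(compressed_causal_mask(8, 2, compress_factor=4)):
        print(q, row)
```

With `compress_factor=1` this reduces to the ordinary lower-triangular causal mask, and for non-negative integers `kv_idx <= q_idx // compress_factor` is equivalent to `kv_idx * compress_factor <= q_idx`.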
Extend compress_factor plumbing to the SM90 (Hopper) forward and backward kernels, matching the SM100 support from the previous commit. The shared mask.py and block_info.py already handle compress_factor for the SM90 code paths (apply_mask, get_n_block_min_max, get_m_block_min_max).

This commit adds the remaining wiring:
- flash_fwd.py (base class): accept compress_factor, store on self
- flash_fwd_sm90.py: pass compress_factor to BlockInfo and AttentionMask
- flash_bwd_sm90.py: accept compress_factor, pass to BlockInfo and AttentionMask
- interface.py: pass compress_factor to SM90 constructors; add to SM90 bwd compile key
|
Are there models / algorithms using this? |
|
Yes, I have an implementation of Sparse Attention (NSA) built on top of FA4 that benefits from this. I do realize the cost of carrying this forward. |
|
So this is an example of a static pattern that *might* benefit from vectorization. I have seen some interesting perf patterns with score_mod (pytorch/pytorch#176055); let me explore #2236 for masking as well. |
|
How can we best evaluate these options? |
|
There are currently two different ways to express the compressed-causal sparsity pattern (kv_idx * 64 <= q_idx). This PR makes the case that we need a third. KV is the compressed sequence (one entry per 64 original positions). A query at position q attends to compressed-KV index kv iff kv * 64 <= q. Constants used below: Q_TILE=256, KV_BLOCK=128, cf=64. Each query-tile is 256 consecutive q-positions; each kv-block is 128 consecutive compressed kv-positions (8192 original positions). So one kv-block is ≈ 32 q-tiles wide along the diagonal.
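The tile arithmetic above can be checked with a small sketch. This is an assumption-laden illustration, not the kernel's block scheduler: `classify_block` and `mask_tiles_for_kv_block` are hypothetical helper names, "skip" is a label introduced here for blocks with no allowed pair, and sequence-length boundaries are ignored.

```python
Q_TILE = 256    # q-positions per query tile (from the comment above)
KV_BLOCK = 128  # compressed kv-positions per kv-block
CF = 64         # original positions per compressed KV entry


def classify_block(q_tile: int, kv_block: int,
                   q_tile_size: int = Q_TILE,
                   kv_block_size: int = KV_BLOCK,
                   cf: int = CF) -> str:
    """Classify a (q-tile, kv-block) pair under the rule kv * cf <= q.

    Hypothetical helper: returns "skip" (no pair allowed), "full"
    (every pair allowed) or "mask" (needs per-element evaluation).
    """
    q_min = q_tile * q_tile_size
    q_max = q_min + q_tile_size - 1
    kv_min = kv_block * kv_block_size
    kv_max = kv_min + kv_block_size - 1
    if kv_min * cf > q_max:    # even the easiest pair is masked out
        return "skip"
    if kv_max * cf <= q_min:   # even the hardest pair is allowed
        return "full"
    return "mask"


def mask_tiles_for_kv_block(kv_block: int, num_q_tiles: int) -> list[int]:
    """Query tiles that straddle the diagonal for one kv-block."""
    return [t for t in range(num_q_tiles) if classify_block(t, kv_block) == "mask"]


if __name__ == "__main__":
    for b in range(2):
        tiles = mask_tiles_for_kv_block(b, num_q_tiles=64)
        print(f"kv-block {b}: {len(tiles)} mask tiles, q-tiles {tiles[0]}..{tiles[-1]}")
```

Under these constants each kv-block is crossed by the diagonal in 32 consecutive query tiles, which matches the 8192 / 256 width quoted above; tiles below that band are "full" and tiles above it are "skip".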
Cost on the user side: literally one int. No tensors to allocate, no transposed bwd layout, nothing to rebuild on shape changes.
Between the removal of the AOT fwd/bwd tensors, the preservation of pack_gqa, and the speedups, this adds real value. Besides, AOT is slower in the bwd pass than mask_mod. The benchmark used to produce this is at https://gist.github.com/jduprat/610cc5992559573a7f7557600f800c45 |
|
Now that we have updated the original benchmark to correctly use blocksparse tensors, not just mask_mod, the delta in perf IMO does not warrant the addition of this extra argument. I do want to find a nice mechanism for using user-defined r2p together with blocksparse. |


