Skip to content

Add compress_factor for compressed causal attention#2418

Open
jduprat wants to merge 2 commits into
Dao-AILab:mainfrom
jduprat:compress-factor
Open

Add compress_factor for compressed causal attention#2418
jduprat wants to merge 2 commits into
Dao-AILab:mainfrom
jduprat:compress-factor

Conversation

@jduprat
Copy link
Copy Markdown
Contributor

@jduprat jduprat commented Mar 31, 2026

Add a compress_factor: int = 1 parameter to _flash_attn_fwd and
_flash_attn_bwd that adjusts the causal mask relationship from
kv_idx <= q_idx to kv_idx <= q_idx // compress_factor.

This enables native causal masking for compressed KV sequences where
each KV token represents compress_factor original tokens (e.g., in
sparse attention). Without this, users must supply a custom mask_mod
function, which is slower because it prevents blocks from being
classified as "full" — all blocks become "mask" blocks requiring
per-element evaluation.

seqlen_q cf compress_factor mask_mod speedup
2048 2 0.070 ms 0.094 ms 1.35x
4096 4 0.091 ms 0.118 ms 1.29x
8192 2 0.298 ms 0.370 ms 1.24x
16384 4 0.510 ms 0.616 ms 1.21x

jduprat added 2 commits March 31, 2026 14:51
Add a `compress_factor: int = 1` parameter to `_flash_attn_fwd` and
`_flash_attn_bwd` that adjusts the causal mask relationship from
`kv_idx <= q_idx` to `kv_idx <= q_idx // compress_factor`.

This enables native causal masking for compressed KV sequences where
each KV token represents `compress_factor` original tokens (e.g., in
sparse attention). Without this, users must supply a custom `mask_mod`
function, which is slower because it prevents blocks from being
classified as "full" — all blocks become "mask" blocks requiring
per-element evaluation.

Currently SM100/SM110 (Blackwell) only. When compress_factor=1 (default),
all code paths are identical to existing behavior.

Changes:
- interface.py: add compress_factor param to _flash_attn_fwd/_flash_attn_bwd,
  include in SM100 compile keys, pass to SM100 kernel constructors
- flash_fwd_sm100.py: accept compress_factor, pass to BlockInfo/AttentionMask
- flash_bwd_sm100.py: same as forward
- block_info.py: add compress_factor field; in get_n_block_min_max compute
  n_idx = m_idx_max // compress_factor + 1; in get_m_block_min_max compute
  m_idx = n_idx_min * compress_factor
- mask.py: add compress_factor field; in apply_mask/apply_mask_sm100 divide
  row_idx by compress_factor and remove seqlen_k-seqlen_q offset; in
  apply_mask_sm100_transposed scale causal_offset by compress_factor
Extend compress_factor plumbing to the SM90 (Hopper) forward and backward
kernels, matching the SM100 support from the previous commit.

The shared mask.py and block_info.py already handle compress_factor for
the SM90 code paths (apply_mask, get_n_block_min_max, get_m_block_min_max).
This commit adds the remaining wiring:

- flash_fwd.py (base class): accept compress_factor, store on self
- flash_fwd_sm90.py: pass compress_factor to BlockInfo and AttentionMask
- flash_bwd_sm90.py: accept compress_factor, pass to BlockInfo and AttentionMask
- interface.py: pass compress_factor to SM90 constructors; add to SM90 bwd compile key
@tridao
Copy link
Copy Markdown
Member

tridao commented Mar 31, 2026

Are there models / algorithms using this?
Adding this to the interface means we must continue to support this in the future to avoid breaking BC

@jduprat
Copy link
Copy Markdown
Contributor Author

jduprat commented Apr 1, 2026

Yes, I have an implementation of Sparse Attention (NSA) built of top of FA4 that benefits from this. I do realize the cost of carrying this forward.

@drisspg
Copy link
Copy Markdown
Collaborator

drisspg commented Apr 1, 2026

So this is an example of a static pattern that :might: benefit from vectorization, I have seem some interesting perf patterns with score_mod pytorch/pytorch#176055

let me explore: #2236 for masking as well

@jduprat
Copy link
Copy Markdown
Contributor Author

jduprat commented Apr 1, 2026

How can we best evaluate these options?

@jduprat
Copy link
Copy Markdown
Contributor Author

jduprat commented Apr 27, 2026

There are currently two different ways to express compressed-causal sparsity pattern (kv_idx * 64 <= q_idx). This PR makes the case that we need a 3rd.

KV is the compressed sequence (one entry per 64 original positions). A query at position q attends to compressed-KV index kv iff kv * 64 <= q. Constants used below: Q_TILE=256, KV_BLOCK=128, cf=64. Each query-tile is 256 consecutive q-positions; each kv-block is 128 consecutive compressed kv-positions (8192 original positions). So one kv-block ≈ 32 q-tiles wide along the diagonal.

  1. mask_mod
    User passes a cute.jit callable: lambda b, h, q, kv, ...: kv * 64 <= q.
    Kernel visits every kv-block for every q-tile (it doesn't know the structure → no tile skipping).
    On every visited element, evaluates the predicate (per-element mask op).
    Disables pack_gqa (the GQA-packing fast path can't reason about an opaque mask).
    Disables the causal R2P (register-to-predicate) fast path — that path is hardcoded for kv <= q, not arbitrary callables.
    Cost: O(N²/cf) compute and per-element predicate work on every visited element.

  2. AOT block-sparse
    User precomputes (block-granular) which kv-blocks each q-tile needs, and labels each kv-block as either full (predicate is true everywhere in the block) or partial (predicate is sometimes true → kernel still needs mask_mod on it).
    Kernel Skips kv-blocks not in either list → tile-skipping.
    For full blocks: no mask op (free).
    For partial blocks: calls mask_mod per element. Still disables pack_gqa (because mask_mod is in the loop for at least one block per tile).

  3. compress_factor (PR#2418)
    User passes one int: compress_factor=64. Kernel modifies the native causal predicate from kv <= q to kv * cf <= q.
    This is "causal with a stride" — it slots into the existing causal R2P fast path (boundary handled analytically in registers, no per-element mask call).
    Block classification (full / boundary / skip) is computed inside the kernel from (q_tile_id, kv_block_id, cf) — no input tensors.
    Per q-tile: identifies the diagonal kv-block analytically; full blocks are handled with no mask predicate at all; the diagonal is handled by R2P (a register-level mask, not a per-element callable invocation).
    pack_gqa works (no mask_mod in the loop).
    Backward goes through the same causal Q-direction-only path with no index tensors.

Cost on the user side: literally one int. No tensors to allocate, no transposed bwd layout, nothing to rebuild on shape changes.

mask_mod AOT block-sparse compress_factor
User input callable 4–8 tensors (fwd + bwd) 1 int
Tile skipping no yes (data-driven) yes (analytical)
Full-block fast path no yes yes
Diagonal/boundary block per-elem mask_mod per-elem mask_mod causal R2P (free)
pack_gqa disabled disabled enabled
Index-tensor memory 0 O(B·H·M·N) 0
Bwd needs separate user input n/a yes (Q-direction) no
Fwd vs mask_mod (N=1M) 3.21× 3.50×
Fwd vs AOT bs (N=1M) 1.09×
Bwd vs mask_mod (N=131K) 0.94× (slower) 1.04×
Bwd vs AOT bs (N=131K) 1.11×
image

Between the removal of the AOT fwd/bwd tensors, the preservation of pack_gqa, and the speedups, this adds real value. Besides AOT is slower in the bwd pass than mask_mod.

The benchmark used to produce this is in https://gist.github.com/jduprat/610cc5992559573a7f7557600f800c45

@drisspg
Copy link
Copy Markdown
Collaborator

drisspg commented Apr 27, 2026

now that we updated the original benchmark to correctly use blocksparse tensors not just mask_mod, the delta in perf IMO does not warrant the addition of this extra arguement,

I do want to find a nice mechanism for using user defined r2p + the blocksparse

@jduprat
Copy link
Copy Markdown
Contributor Author

jduprat commented Apr 28, 2026

The graph above was very low dpi, which made specifics hard to see. It was also log-scale which didn't help. Reattached so the perf gains are more clear —
bench_compress_factor
bench_compress_factor_linear

To summarize the performance gains —
– At least 9% faster than AOT in the forward pass
– At least 9% faster than mask_mod in the backward pass (AOT slower than mask_mod in backward pass)
– Enables pack_gqa which brings another level of perf itself
– Avoids the extra AOT tensors

Any suggestions for user defined r2p/blocksparse mechanisms to explore?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants