Skip to content

[Cute,Fwd,Sm100] Add r2p for local mask#2185

Merged
v0i0 merged 12 commits intoDao-AILab:mainfrom
henrylhtsang:test_local_r2p
Jan 17, 2026
Merged

[Cute,Fwd,Sm100] Add r2p for local mask#2185
v0i0 merged 12 commits intoDao-AILab:mainfrom
henrylhtsang:test_local_r2p

Conversation

@henrylhtsang
Copy link
Copy Markdown
Contributor

@henrylhtsang henrylhtsang commented Jan 15, 2026

The idea is similar to mask_r2p.

Using xor mask_range = mask_right & ~ mask_left because:

upper bound: 0b00011111
lower bound: 0b00000001
upper & ~ lower : 0b00011110

Notes:

  • Original idea was to use xor. But that would require assuming window_size_left + window_size_right >= 0. I can't tell if xor is really better.
  • I wanted to avoid taking min(right_s, 24). But I would encounter an error OverflowError: Python int too large to convert to C long at mask_range = mask_right & ~mask_left

Benchmark

Before

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
FA Python fwd: 0.304ms, 875.9 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
FA Python fwd: 0.442ms, 1167.5 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
FA Python fwd: 0.723ms, 1330.8 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
FA Python fwd: 1.144ms, 1442.1 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
FA Python fwd: 0.232ms, 574.0 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
FA Python fwd: 0.297ms, 869.3 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
FA Python fwd: 0.417ms, 1152.8 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
FA Python fwd: 0.635ms, 1299.5 TFLOPS

After

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
FA Python fwd: 0.281ms, 948.1 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
FA Python fwd: 0.426ms, 1209.6 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
FA Python fwd: 0.708ms, 1358.8 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
FA Python fwd: 1.131ms, 1458.5 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
FA Python fwd: 0.206ms, 648.4 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
FA Python fwd: 0.275ms, 938.9 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
FA Python fwd: 0.400ms, 1202.3 TFLOPS

### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
FA Python fwd: 0.621ms, 1329.1 TFLOPS

test

pytest . -x 

============================================================================================ 49957 passed, 34769 skipped, 1325295 warnings in 20300.02s (5:38:20) ============================================================================================

Add mask_r2p_dual_bound function using XOR of two bitmasks
to efficiently mask elements outside [col_limit_left, col_limit_right)
range for SM100 local attention.
Add mask_r2p_dual_bound function using XOR of two bitmasks
to efficiently mask elements outside [col_limit_left, col_limit_right)
range for SM100 local attention.
@henrylhtsang
Copy link
Copy Markdown
Contributor Author

maybe @jayhshah @v0i0 ?

@henrylhtsang
Copy link
Copy Markdown
Contributor Author

henrylhtsang commented Jan 16, 2026

I compared xor vs mask_range = mask_right & ~ mask_left. They are pretty close, probably 0.5% difference, but xor seems faster.

updated to use out_bound. perf is better. nah

@tridao
Copy link
Copy Markdown
Member

tridao commented Jan 16, 2026

This is awesome!

@henrylhtsang henrylhtsang marked this pull request as draft January 16, 2026 02:13
@henrylhtsang henrylhtsang marked this pull request as ready for review January 16, 2026 04:45
@henrylhtsang henrylhtsang requested a review from v0i0 January 16, 2026 04:45
@henrylhtsang henrylhtsang marked this pull request as draft January 16, 2026 17:37
@henrylhtsang henrylhtsang marked this pull request as ready for review January 16, 2026 22:27
@henrylhtsang
Copy link
Copy Markdown
Contributor Author

@v0i0 noob question: I don't see a merge button. Can you help merge for me? tests just finished running

@v0i0 v0i0 merged commit 2d6b146 into Dao-AILab:main Jan 17, 2026
elewarr pushed a commit to elewarr/flash-attention that referenced this pull request Feb 4, 2026
[Cute,Fwd,Sm100] Add r2p for local mask
YangWang92 pushed a commit to YangWang92/flash-attention that referenced this pull request Feb 15, 2026
[Cute,Fwd,Sm100] Add r2p for local mask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants