670 commits
bba578d
Fix IMA in fwd on m boundary (#2091)
drisspg Dec 20, 2025
ceb4110
Update to dsl 3.4.3 (#2092)
drisspg Dec 22, 2025
5663adf
README for AMD ROCm (#2068)
seungrokj Dec 23, 2025
58fe37f
fix shuffle sync for pack gqa epilogue (#2097)
jayhshah Dec 24, 2025
11b32fd
improve paged cpasync
v0i0 Dec 24, 2025
d234051
Enable Thor (#2108)
johnnynunez Dec 29, 2025
4fd123e
[Cute] Add quack as dependency
tridao Dec 31, 2025
f3423a8
[Cute,Fwd,Sm90] Change PipelineTMAAsync subclass to signal per warp
tridao Jan 1, 2026
9b6dbac
Add pack-gqa support for blocksparse impl w/ broadcasted H dim (#2098)
drisspg Jan 4, 2026
f98d345
[Cute,Fwd] improved block sparsity (#2100)
reubenconducts Jan 5, 2026
bb2efb3
[Cute] Fix minor lint issue in shuffle_sync
tridao Jan 5, 2026
f472175
Misc tests that should be xfailed for now (#2127)
drisspg Jan 5, 2026
3e87e42
Update cutlass to fix undefined symbol: cuDriverGetVersion. (#2142)
HydraQYH Jan 7, 2026
3c8ca4e
[Cute,Fwd,Sm100] Support `q_stage=1` for inference (#1993)
timmy-feng Jan 8, 2026
6dd7e74
[Cute] Fix two tests that were failing (#2149)
henrylhtsang Jan 8, 2026
c15ffe3
cleanup
v0i0 Jan 8, 2026
ed6a82f
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
jayhshah Jan 9, 2026
27a3b54
block-sparse backward SM90 (#2136)
drisspg Jan 10, 2026
844b10f
score-mod backward SM90 (#2137)
drisspg Jan 10, 2026
e317aa4
[Cute] Clarify and fix subtle cachekey bug (#2143)
drisspg Jan 10, 2026
26d4ee9
[CUTE][SM100] Fix backward gqa on sm100 post mask-mod semantic change…
drisspg Jan 10, 2026
8eff546
[CUTE][SM90]Enable pack-gqa with broadcasted maskmods (#2145)
drisspg Jan 10, 2026
5d4c953
[CUTE][SM90] GQA backward non deterministic (#2158)
drisspg Jan 10, 2026
ea8f735
[Cute,Bwd,Sm100] fix seqused in varlen bwd (#2167)
jayhshah Jan 10, 2026
ef7343b
[CUTE] Bump cutedsl to 4.3.5 (#2170)
drisspg Jan 12, 2026
dbf08eb
Merge pull request #2156 from v0i0/v0i0/improve-paged-ldgsts
v0i0 Jan 12, 2026
4cb272e
[Cute,Flex] Add option to create and cache __cute_hash__ (#2171)
reubenconducts Jan 12, 2026
4894657
[Cute][Flex] Remove no longer needed contig (#2172)
drisspg Jan 12, 2026
13696f2
[Cute] update row_max before safe overwrite for online_softmax (#2174)
jayhshah Jan 13, 2026
506441a
[Cute][Flex] add back in contig (#2177)
drisspg Jan 15, 2026
68649fb
[Cute][Flex]Add pack-gqa divmod (#2180)
drisspg Jan 15, 2026
88067b0
baseline local flops
henrylhtsang Jan 15, 2026
fffabc3
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
timmy-feng Jan 15, 2026
a512bd8
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
2020964
remove benchmark result, undo changes to benchmark
henrylhtsang Jan 15, 2026
7108d1c
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
e4ec1ad
switch from xor to mask_right & ~ mask_left
henrylhtsang Jan 16, 2026
ac88858
flip in_bound to out_bound
henrylhtsang Jan 16, 2026
e34d840
remove zero logic for right_s and left_s
henrylhtsang Jan 16, 2026
08e6518
remove 24 clamp
henrylhtsang Jan 16, 2026
94f0348
doc
henrylhtsang Jan 16, 2026
e94012a
lint
henrylhtsang Jan 16, 2026
2e6ae05
added back clamp to avoid "OverflowError: Python int too large to con…
henrylhtsang Jan 16, 2026
137ad8e
add comment
henrylhtsang Jan 16, 2026
2d6b146
Merge pull request #2185 from henrylhtsang/test_local_r2p
v0i0 Jan 17, 2026
a0f9f41
[Cute][Flex] Fix expanded tensor bug (#2189)
drisspg Jan 17, 2026
04e6ee1
[Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
KareemMusleh Jan 20, 2026
f15ccf5
reduce chance of build oom (#2079)
Qubitium Jan 21, 2026
2580b5a
[Cute][Flex] Allow q_offset 1 and add block-sizes to disambiguate edg…
drisspg Jan 22, 2026
57cef6c
ci: Use 1 ninja job for cu13 (#2195)
ko3n1g Jan 24, 2026
438325c
Update README to include 'psutil' package as build requirement (#2210)
wanglc02 Jan 25, 2026
4f89246
[Flex][SM100] Replay expand fix on sm100 (#2209)
drisspg Jan 26, 2026
99589e5
[DSL] Optionally patch cute-dsl to use system's ptxas
tridao Jan 27, 2026
701ebe0
[AMD] Triton Backend for ROCm #3 (#2178)
micmelesse Jan 28, 2026
514e63c
fix compute_block_sparsity usage in benchmark_mask_mod (#2221)
zhuochenKIDD Feb 2, 2026
188643b
Fix shared-memory race (#2229)
drisspg Feb 4, 2026
ef9e6a6
Use TORCH_TARGET_VERSION over TORCH_STABLE_ONLY (#2155)
janeyx99 Feb 4, 2026
24445c0
short readme for flex flash (#2231)
v0i0 Feb 5, 2026
e2743ab
[FA3] Mark current main version as v3.0.0 stable (#2223)
lw Feb 5, 2026
f1284cf
hdim 192 smem fix (#2235)
jayhshah Feb 5, 2026
912c6c4
Add `FLASH_ATTENTION_TRITON_AMD_CONFIG_JSON` env var support (#2239)
alexheretic Feb 7, 2026
abaa878
[CUTE]Bump to Cutedsl (#2216)
drisspg Feb 8, 2026
48af662
pytest-dist round robin to gpus (#2241)
drisspg Feb 8, 2026
a804a5a
[DSL] Replace old fence with cute.arch.fence_view_async_shared()
tridao Feb 8, 2026
5a66f2c
[DSL]Replace utils.{fma,mul,add}_packed_f32x2 with cute.arch version
tridao Feb 8, 2026
d39b629
[DSL] Remove coord_offset_i64, domain_offset_i64, elem_pointer_i64
tridao Feb 8, 2026
81f2c2d
[Sm90] Use functions from quack.sm90_utils
tridao Feb 8, 2026
7edcf59
[DSL] Use cute.arch.warp_reduction_{max,sum}
tridao Feb 8, 2026
b735ef2
[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack
tridao Feb 8, 2026
8dd8019
[Layout] Use quack.layout_utils.mma_partition_C_vec
tridao Feb 8, 2026
90f10fa
[DSL] Use cute.math.{exp2,log2,log}
tridao Feb 8, 2026
b9148ce
[Layout] Use layout_utils.transpose_view and select from quack
tridao Feb 8, 2026
c912a37
[Bwd,Sm90] Use quack.copy_utils
tridao Feb 8, 2026
deb1830
[Bwd,Sm100] Shorten PipelineTmaUmma create
tridao Feb 8, 2026
17d2943
[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions
tridao Feb 8, 2026
2a8d39c
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
tridao Feb 8, 2026
72c7ba4
Fix Hopper tests (#2242)
drisspg Feb 8, 2026
a5856bf
[Bwd,Sm90] For dQ, move wait_group before TMA atomic add
tridao Feb 11, 2026
c4d8b06
[Cute,Flex,Fwd] Allow vectorized score_mod definitions (#2236)
reubenconducts Feb 11, 2026
16d16d8
[Bwd,Sm90] Simplify dK/dV R2S copy
tridao Feb 14, 2026
ad2f470
[DSL] Use quack.cute_dsl_utils.ParamsBase
tridao Feb 14, 2026
b62d93f
Fix int32 overflow (#2260)
drisspg Feb 15, 2026
fec3a6a
[Cute][Flex] Fix kernel hang w/ multiple empty tiles (#2258)
drisspg Feb 16, 2026
a8780f2
Bump to 4.4.0 cute dsl pin (#2262)
drisspg Feb 18, 2026
710d3cc
BWD sm100 2cta (#2202)
tzadouri Feb 20, 2026
6079a9b
[Bwd,Sm100] Fix num reg variables
tridao Feb 20, 2026
05eea8b
[Cute] Change compute_capability to arch
tridao Feb 20, 2026
884f72d
[Bwd,Postprocess] Update api to cute.arch.fence_view_async_shared
tridao Feb 21, 2026
8e0b5d7
[Fwd,Sm100] Disable ex2 emulation for Sm103
tridao Feb 21, 2026
fe878cc
[Dep] Update quack dependency to 0.2.10
tridao Feb 21, 2026
463623e
[Fwd,Sm100] Use arch from BaseDSL._get_dsl().get_arch_enum()
tridao Feb 21, 2026
5caef45
[Fwd,Sm100] Clean up
tridao Feb 22, 2026
7c9981e
[Bwd,Sm100] Put 2CTA asserts under if const_expr
tridao Feb 22, 2026
d5515cb
[Fwd,Sm100] Refactor _store_O_to_gemm into a separate method
tridao Feb 22, 2026
3dd5d83
[Fwd,Sm100] Simplify tensor layouts
tridao Feb 22, 2026
6287355
[Fwd,Sm100] Use pipeline_kv in load_KV instead of raw mbarrier
tridao Feb 23, 2026
9136b0c
[DSL] Don't need to parse swizzle from str anymore
tridao Feb 23, 2026
8d9e28b
[Fwd,Sm100] Use position_independent for sO, more clean up
tridao Feb 23, 2026
a595ceb
[Fwd,Sm100] Use pipeline abstraction for loading Q and KV
tridao Feb 23, 2026
5678dd9
[Cute] Handle window_size=(-1, -1) for non-local attention (#2251)
henrylhtsang Feb 23, 2026
0ba6f22
Document usage with 🤗 Kernels (#2272)
sayakpaul Feb 25, 2026
156137b
[Cute,Sm100,Bwd] Add hdim 192 hdimv 128 backward for sm100 (#2270)
jayhshah Feb 25, 2026
2c0f11e
[Fwd,Sm100] Only 1 thread per warp signals mbar_P_full_2
tridao Feb 23, 2026
01a8b74
[Fwd,Sm100] Use pipeline abstraction for S_full & P_full_O_rescaled
tridao Feb 25, 2026
405df75
[Fwd,Sm100] Use pipeline abstraction for softmax-correction mbarrier
tridao Feb 25, 2026
e0bc9ca
[Fwd,Sm100] Use pipeline abstraction for correction-epilogue
tridao Feb 25, 2026
76d7362
[Fwd,Sm100] Tune registers
tridao Feb 25, 2026
484a5dc
Correct cutlass error handling (#2273)
ankutalev Feb 25, 2026
0586d2e
guard use_2cta_instrs on sm90 (#2274)
reubenconducts Feb 25, 2026
5963594
[cute] Add return_lse (#2271)
erikwijmans Feb 26, 2026
ffbc678
[Fwd,Sm100] Use pipeline abstraction for O_full
tridao Feb 25, 2026
cf027a4
[Fwd,Sm100] Use pipeline abstraction for mbar_P_full_2
tridao Feb 26, 2026
ed85ed7
[Fwd,Sm100] Use TmemAllocator
tridao Feb 26, 2026
0293155
[Fwd,Sm100] Set split_P_arrive as a tunable parameter
tridao Feb 26, 2026
bf4d8ee
[Fwd,Sm100] Use pipeline abstraction for s0_s1_sequence
tridao Feb 26, 2026
aa5f7db
[Fwd,Sm100] Fix tScS partitioning for score_mod
tridao Feb 26, 2026
944e457
fix mask mod bugs (#2276)
reubenconducts Feb 26, 2026
a00ddeb
[Cute,Sm100,Bwd] Fix and enable 2CTA path for hdim 128 backward (#2280)
jayhshah Feb 28, 2026
01bc8ef
[Fwd,Sm100] Change layout of gQ and gO to have q_stage
tridao Feb 27, 2026
d1d3e8d
[Fwd,Sm100] Pass cta_layout_vmnk to pipelines
tridao Feb 27, 2026
58d0c57
[Fwd,Sm100] Gate mma with is_leader_cta
tridao Feb 27, 2026
a631802
[Fwd,Sm100] Take into account mma_tile_coord_v when reading/writing
tridao Feb 27, 2026
b936061
[Fwd,Sm100] Add pipeline.producer_tail
tridao Feb 27, 2026
9aadb8b
[Fwd,Sm100] Enable 2CTA for hdim128 noncausal
tridao Feb 28, 2026
7ed0898
Bump to 4.4.1 to avoid segfault (#2291)
drisspg Feb 28, 2026
6d36c1c
Fix sm100 fwd missing tSrQs init regression (#2293)
drisspg Mar 1, 2026
d146eff
[Scheduler] Revert SingleTileScheduler to get block_idx
tridao Mar 1, 2026
ceb1099
Fix clang parser error of missing 'typename' prior to dependent type …
tomflinda Mar 2, 2026
be76c60
[CuTe] Include broadcast dims in backward compile cache keys (#2298)
bonpyt Mar 3, 2026
d78c84a
[Fwd,Sm100] Use NamedBarrier to signal softmax -> corr warps
tridao Mar 3, 2026
990b510
[Fwd,Sm100] Add polynomials degree 1 - 5
tridao Mar 3, 2026
72eb5de
[Fwd,Sm100] Switch back to poly degree 3
tridao Mar 3, 2026
51b6575
[Fwd,Sm100] Compute kv_stage based on hdim instead of hard-coding
tridao Mar 3, 2026
f2682b6
[Cute][Testing] Add fake tensor mode support for compile-only test pa…
Alkaid-Benetnash Mar 3, 2026
9d871f9
Enable hdim=96 bwd (#2302)
v0i0 Mar 3, 2026
4d9c722
Fix GQA crash in cute FLASH backend: init load_Q before conditional (…
platers Mar 3, 2026
884a52a
[Fwd,Sm100] Be more explicit when loading Q
tridao Mar 3, 2026
dd15c02
[Fwd,Sm100] Tune ex2_emu_freq
tridao Mar 3, 2026
c799762
[Fwd,Sm100] Tweak ptx for gemm
tridao Mar 3, 2026
0d943f8
[Bench] Enable benchmarking bwd with headdim != headdim_v
tridao Mar 3, 2026
2b5db43
fix paged kv (#2303)
jayhshah Mar 3, 2026
d51a4a1
Add FA4 publishing strategy (#2282)
drisspg Mar 3, 2026
9a25eba
[Cute][Testing] Add persistent compile cache for cutedsl AOT compilat…
Alkaid-Benetnash Mar 4, 2026
1b2a6cd
[Bench] Add reference attn implementation
tridao Mar 5, 2026
a79ee34
[Bwd,Sm100] Use TmemAllocator
tridao Mar 5, 2026
a365a19
Change PyPI name to flash-attn4
tridao Mar 5, 2026
253ecf5
Try to publish to PyPI again
tridao Mar 5, 2026
dc754c7
Try again
tridao Mar 5, 2026
3e643ef
Change PyPI package name to fa4
tridao Mar 5, 2026
120b306
Add fa4_paper.pdf
tridao Mar 5, 2026
5ded17f
Add DEBUG_2CTA.md
tridao Mar 5, 2026
d91ea94
[Bwd,Sm100] Add fence_view_async_shared before LSE release
tridao Mar 5, 2026
fbe1568
Change PyPI name back to flash-attn-4
tridao Mar 5, 2026
8a7af50
[Cute][Testing] Minor improvements on pytest-xdist workflow (#2311)
Alkaid-Benetnash Mar 6, 2026
47a5899
[AI] Racecheck cp.async.bulk false positive investigation
tridao Mar 7, 2026
d337e64
[AI] Update racecheck false positive reproducer
tridao Mar 7, 2026
8451d4e
[Bwd,Sm103] Fix postprocess for 2cta_instrs
tridao Mar 7, 2026
dd8a272
[Sm100] Fix tmem dealloc: sync before dealloc
tridao Mar 8, 2026
13bd5e6
[Test] Skip non-files in cache_utils.py
tridao Mar 8, 2026
817a1a0
Add more code authors
tridao Mar 8, 2026
485e8fe
[AI] Add CLAUDE.md
tridao Mar 8, 2026
0f375dd
Nicer headdim error message (#2227)
drisspg Mar 9, 2026
7b1581f
[Fwd,Sm100] Extract named barriers (#2309)
drisspg Mar 9, 2026
42c5765
Change 2cta opt in to have min seqlen > 2*m_block_size (#2320)
drisspg Mar 9, 2026
e3dd324
[CuteDSL][SM90] varlen bwd works (#2275)
KareemMusleh Mar 10, 2026
4d65522
[Fwd,Sm90] Move FwdSm90 to a separate file
tridao Mar 8, 2026
b3dd9c3
[GQA] Refactor pack_gqa_layout into a helper function
tridao Mar 8, 2026
a941b76
[Fwd] Refactor compute_softmax_scale_log2 and comptue_fastdiv_mods
tridao Mar 8, 2026
7fd16f2
[GQA] Add unpack_gqa_layout
tridao Mar 8, 2026
c04a808
[CI] Drop Pytorch 2.5, add Pytorch 2.10, add CUDA 13
tridao Mar 10, 2026
1314ea2
Add Logging helper (#2327)
drisspg Mar 11, 2026
4d2be70
[Sm80] basic fix for new api (#2297)
zhuochenKIDD Mar 11, 2026
80ff9a9
fix: duplicate softmax_scale param (#2328)
NanoCode012 Mar 11, 2026
706ae51
[Bwd] Compile bwd_preprocess with cute fake tensors
tridao Mar 10, 2026
6de65f3
[Bwd] Clean up bwd_preprocess kernel
tridao Mar 10, 2026
bdf123b
[Fwd] Port SeqlenInfoQKNewK from C++ to cute-dsl
tridao Mar 11, 2026
031b178
[Fwd] Clean up fwd_combine kernel, compile w cute fake tensors
tridao Mar 11, 2026
6d2ccf2
[Test] Add test_flash_attn_fast.py
tridao Mar 11, 2026
d7fb450
[Fwd,Sm80] Fix import of BlockSparseTensors
tridao Mar 11, 2026
99d0148
[Fwd,Sm90] Tune tile size for hdim 64, 96, 128
tridao Mar 11, 2026
b664ea0
[Bwd,Sm90] Implement deterministic
tridao Mar 11, 2026
4219765
Fix FA2 + FA4 co-existence (#2331)
drisspg Mar 11, 2026
5dcdfcc
[Cute,Sm100] Introduce a flexible lambda-based R2P masking (#2313)
Alkaid-Benetnash Mar 12, 2026
fb0fb1b
[Bwd] Compile bwd_postprocess with cute fake tensors
tridao Mar 11, 2026
161b0e6
fix(sm90): plumb seqused_q/k through SM90 backward pass (#2315)
NJX-njx Mar 12, 2026
fb69aef
SM120 forward pass (Blackwell GeForce / DGX Spark) (#2329)
blake-snc Mar 12, 2026
966c69c
[cutlass] Allow compilation of cutlass FA3 for sm100 via enable_sm90 …
henrylhtsang Mar 12, 2026
064377e
rename logging module (#2335)
Luosuu Mar 12, 2026
b900a5b
Add tile parameter to SeqlenInfo creation (#2337)
risan-raja Mar 12, 2026
bbe25ba
Fix (#2338)
MatthewBonanni Mar 12, 2026
3f94643
[AMD] Migrate to Triton Backend to Aiter (#2230)
micmelesse Mar 12, 2026
5846714
[Bwd,Sm90] Pass tile_m to bwd_preprocess, enable varlen tests
tridao Mar 12, 2026
0687866
[Fwd,Sm90] Use mask_r2p_lamba
tridao Mar 12, 2026
076e004
[Bwd,Sm90] Fix varlen scheduler
tridao Mar 12, 2026
9dfd79e
[Bwd,Sm90] Enable varlen tests with seqused_k
tridao Mar 12, 2026
921f70c
[Bwd,Sm120] Add SM120 backward pass support (#2330)
blake-snc Mar 12, 2026
7cd4fcb
[Bwd,Sm90] Enable local
tridao Mar 12, 2026
1084696
fix tdKrdS typo (#2341)
henrylhtsang Mar 13, 2026
f75386f
[Bwd,Sm90] Implement ShuffleLSE
tridao Mar 13, 2026
a4510a6
[Bwd] Support gradient wrt LSE
tridao Mar 13, 2026
4c8e5ee
Add SM120 varlen attention support (#2333)
blake-snc Mar 13, 2026
6d4b554
[Bwd] Use ragged tensor for TMA dKV when varlen
tridao Mar 13, 2026
3a0d6bb
[Bwd,Sm90] Set bwd configs, make hdim64 bwd work
tridao Mar 13, 2026
ef00082
fix the create_ragged_tensor_for_tma issue (#2345)
rainj-me Mar 14, 2026
3f02df3
[Fwd,Sm90] Implement rescale_O_before_gemm, enable hdim 192 & 256
tridao Mar 14, 2026
ff118ac
[Fwd,Sm90] Add hdim 192 and 256 to _validate_head_dims
tridao Mar 14, 2026
10bbfd0
Support CPU-only compilation and overriding arch
tridao Mar 14, 2026
cea37f5
[Bwd,Sm90] Implement PDL between bwd_preprocess and bwd
tridao Mar 15, 2026
9d242ba
[Bench] Refactor benchmark script to take args from cmdline
tridao Mar 15, 2026
c325b5b
[Bwd,Sm90] Implement dQ_single_wg
tridao Mar 15, 2026
45552b3
[Bwd,Sm90] Make hdim 96 work
tridao Mar 15, 2026
5c7c494
[Sm90] Add script to search fwd bwd configs
tridao Mar 15, 2026
fb5c4fa
[Sm90] Clean up sm90_config_search.py
tridao Mar 15, 2026
71bf77c
[Bwd,Sm90] Implement hdim 192-128 and hdim 192
tridao Mar 15, 2026
ab0efcb
[Sm90] Fix test_mask_mod and bwd block-sparse kwarg mismatch (#2365)
henrylhtsang Mar 17, 2026
3e40b8e
[Cute, Testing] Move stream parameter to end of kernel __call__ signa…
Alkaid-Benetnash Mar 18, 2026
3387de4
[Cute] Bump cutedsl to 4.4.2 and remove prior aot workarounds (#2370)
Alkaid-Benetnash Mar 18, 2026
ce917a6
[Cute] fix: FA4 paged attention kv load for DeepSeek (192,128) on SM1…
Luosuu Mar 18, 2026
8afc617
[AMD ROCm] Update ROCm/CK backend to align with latest ComposableKern…
rocking5566 Mar 18, 2026
0f82fea
[ROCm] Auto-detect Triton backend if C++ extension is missing (#2343)
Soddentrough Mar 18, 2026
0372975
[Fwd,Sm90] Add paged KV attention support (tma and cp.async) (#2360)
henrylhtsang Mar 18, 2026
e4b46ed
[Tool] Add sass_diff
tridao Mar 18, 2026
9397816
[CuTe,Flex] limit vec_size to 2 for score mod when not on Sm100 (#2371)
reubenconducts Mar 18, 2026
951f398
[Fwd,Sm90] Use pipeline_q instead of raw mbarrier
tridao Mar 19, 2026
5f68230
[Fwd,Sm90] Use TMA for O when PackGQA, keep no. TMA dim when PackGQA
tridao Mar 19, 2026
07bd3af
Support 2CTA for sliding window hdim 192 (#2347)
Inodayy Mar 19, 2026
dd6f4a2
[Fwd,Sm90] Use TMA for O when varlen
tridao Mar 19, 2026
3250081
support irregular q to kv head ratio (#2186)
timmy-feng Mar 20, 2026
cc8cb90
[Pipeline] Refactor
tridao Mar 20, 2026
9f919c6
[Fwd,Sm90] Use producer instead of mma warps to load Q when !TMA_Q
tridao Mar 20, 2026
6317062
[Fwd,Sm90] Implement PipelineAsync with elect_one for commit/release
tridao Mar 20, 2026
3cafddf
benchmarks: add MFU% column with dtype-aware peak FLOPS (#2377)
Johnsonms Mar 22, 2026
e1d2e07
[Bench] Clean up peak flops numbers
tridao Mar 22, 2026
ef736fe
[DSL] Remove ArgumentsBase
tridao Mar 22, 2026
6362bd3
Update flow to enable beta weekly releases (#2378)
drisspg Mar 23, 2026
e065bf1
Recommend cu13 extra for best perf
tridao Mar 25, 2026
28ef22c
fix: use LSE accum strides from params instead of hardcoded ones (#2388)
ZeronSix Mar 25, 2026
b8eda39
Add README link for Turing support (#2379)
ssiu Mar 25, 2026
5c7711e
refine bwd swizzle when deterministic (#2390)
jayhshah Mar 25, 2026
b2176fd
[Fwd,Sm100] Tune ex2 frequency and registers
tridao Mar 25, 2026
3c20009
[Fwd,Sm100] Enable 2CTA for hdim 192-128 noncausal
tridao Mar 25, 2026
abd9943
Fix edge case when tag has no delta from previous (#2394)
drisspg Mar 25, 2026
4fcfdec
[Fwd,Sm100] Clean up pipeline creation a bit
tridao Mar 26, 2026
5301a35
[AMD ROCm] Update CK and add RDNA 3/4 support (#2400)
rocking5566 Mar 26, 2026
98024f9
[Ai-assisted] CLC work stealing (#2218)
drisspg Mar 28, 2026
66bedce
Various bug fixes / enable subtile > 2 (#2411)
drisspg Mar 30, 2026
29e40cf
Add to varlen (#2346)
drisspg Mar 30, 2026
f6a16e1
Allow compact block sparse index tensors (#2417)
jduprat Apr 1, 2026
c86d26e
Merge remote-tracking branch 'upstream/main' into sync-upstream-cute
MatthewBonanni Apr 1, 2026
db466bd
Merge remote-tracking branch 'upstream/main' into sync-upstream-cute
MatthewBonanni Apr 1, 2026
7 changes: 5 additions & 2 deletions flash_attn/cute/AUTHORS
```diff
@@ -1,5 +1,8 @@
-Tri Dao, tri@tridao.me
+Tri Dao
 Jay Shah
 Ted Zadouri
 Markus Hoehnerbach
-Vijay Thakkar
+Vijay Thakkar
+Timmy Liu
+Driss Guessous
+Reuben Stern
```
5 changes: 5 additions & 0 deletions flash_attn/cute/MANIFEST.in
```diff
@@ -0,0 +1,5 @@
+global-exclude *.egg-info/*
+prune flash_attn_4.egg-info
+prune flash_attn.egg-info
+prune build
+prune dist
```
40 changes: 23 additions & 17 deletions flash_attn/cute/README.md
````diff
@@ -1,26 +1,32 @@
-# Flash Attention CUTE
+# FlashAttention-4 (CuTeDSL)
 
-## Development Installation
+FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
 
-1. Clone the repository (if you haven't already):
-```bash
-git clone https://github.com/Dao-AILab/flash-attention.git
-cd flash-attention/cute
-```
+## Installation
 
-2. Install in editable mode with dev dependencies:
-```bash
-pip install -e "./cute[dev]"
-```
+```sh
+pip install flash-attn-4
+```
 
-## Running Tests
+If you're on CUDA 13, install with the `cu13` extra for best performance:
 
-```bash
-pytest tests/cute/
+```sh
+pip install "flash-attn-4[cu13]"
 ```
 
-## Linting
+## Usage
 
-```bash
-ruff check flash_attn/cute/
+```python
+from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
+
+out = flash_attn_func(q, k, v, causal=True)
 ```
+
+## Development
+
+```sh
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention
+pip install -e "flash_attn/cute[dev]"
+pytest tests/cute/
+```
````
7 changes: 6 additions & 1 deletion flash_attn/cute/__init__.py
```diff
@@ -1,6 +1,11 @@
 """Flash Attention CUTE (CUDA Template Engine) implementation."""
 
-__version__ = "0.1.0"
+from importlib.metadata import PackageNotFoundError, version
+
+try:
+    __version__ = version("fa4")
+except PackageNotFoundError:
+    __version__ = "0.0.0"
 
 import cutlass.cute as cute
```
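The `__version__` change above follows the standard `importlib.metadata` pattern: read the version from the installed distribution's metadata, and fall back to a sentinel when the package is not installed. A minimal standalone sketch of that pattern (the `resolve_version` helper and the distribution name are illustrative, not part of the diff):

```python
from importlib.metadata import PackageNotFoundError, version


def resolve_version(dist_name: str) -> str:
    # Prefer the installed distribution's metadata; fall back to a
    # sentinel when the distribution is absent, as in the diff above.
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "0.0.0"


print(resolve_version("definitely-not-an-installed-distribution"))  # → 0.0.0
```

This keeps the version string in one place (the package metadata) instead of duplicating it in the source.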
196 changes: 196 additions & 0 deletions flash_attn/cute/bench_utils.py
@@ -0,0 +1,196 @@

```python
"""Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation."""

import math
import torch

try:
    import cudnn
except ImportError:
    cudnn = None


# ── FLOPS calculation ────────────────────────────────────────────────────────


def flops(
    batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (None, None):
            avg_seqlen = seqlen_k
        else:
            row_idx = torch.arange(seqlen_q, device="cuda")
            col_left = (
                torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
                if window_size[0] is not None
                else torch.zeros_like(row_idx)
            )
            col_right = (
                torch.minimum(
                    row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)
                )
                if window_size[1] is not None
                else torch.full_like(row_idx, seqlen_k - 1)
            )
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


# ── Reference attention ─────────────────────────────────────────────────────

_attention_ref_mask_cache = {}


def attention_ref(q, k, v, causal=False):
    """Standard attention reference implementation.

    Args:
        q, k, v: (batch, seqlen, nheads, headdim) tensors.
        causal: whether to apply causal mask.
    """
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    if causal:
        # Cache on (seqlen_q, seqlen_k) so differing seqlen_k doesn't reuse a stale mask.
        key = tuple(scores.shape[-2:])
        if key not in _attention_ref_mask_cache:
            # Mask entries strictly above the diagonal (future keys).
            _attention_ref_mask_cache[key] = torch.triu(
                torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=1
            )
        mask = _attention_ref_mask_cache[key]
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)


# ── cuDNN graph helpers ─────────────────────────────────────────────────────

_TORCH_TO_CUDNN_DTYPE = {
    torch.float16: "HALF",
    torch.bfloat16: "BFLOAT16",
    torch.float32: "FLOAT",
    torch.int32: "INT32",
    torch.int64: "INT64",
}


def _build_cudnn_graph(io_dtype, tensors, build_fn):
    """Build a cuDNN graph. Returns (graph, variant_pack, workspace)."""
    assert cudnn is not None, "cuDNN is not available"
    cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype])
    graph = cudnn.pygraph(
        io_data_type=cudnn_dtype,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()}
    variant_pack = build_fn(graph, graph_tensors)
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
    return graph, variant_pack, workspace


def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):
    """Build a cuDNN forward SDPA graph.

    Args:
        q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout).
        causal: whether to apply causal mask.
        window_size_left: sliding window size (None for no window).

    Returns:
        (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable.
    """
    b, nheads, seqlen_q, headdim = q.shape
    headdim_v = v.shape[-1]
    o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)

    def build(graph, gt):
        o, stats = graph.sdpa(
            name="sdpa",
            q=gt["q"],
            k=gt["k"],
            v=gt["v"],
            is_inference=False,
            attn_scale=1.0 / math.sqrt(headdim),
            use_causal_mask=causal or window_size_left is not None,
            sliding_window_length=window_size_left
            if window_size_left is not None and not causal
            else None,
        )
        o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
        stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
        return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu}

    graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build)

    def fwd_fn():
        graph.execute(variant_pack, workspace)
        return o_gpu

    return fwd_fn, o_gpu, stats_gpu


def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):
    """Build a cuDNN backward SDPA graph.

    Args:
        q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout).
        causal: whether to apply causal mask.
        window_size_left: sliding window size (None for no window).

    Returns:
        bwd_fn: zero-arg callable that returns (dq, dk, dv).
    """
    headdim = q.shape[-1]
    dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)

    def build(graph, gt):
        dq, dk, dv = graph.sdpa_backward(
            name="sdpa_backward",
            q=gt["q"],
            k=gt["k"],
            v=gt["v"],
            o=gt["o"],
            dO=gt["g"],
            stats=gt["lse"],
            attn_scale=1.0 / math.sqrt(headdim),
            use_causal_mask=causal or window_size_left is not None,
            sliding_window_length=window_size_left
            if window_size_left is not None and not causal
            else None,
            use_deterministic_algorithm=False,
        )
        dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
        dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
        dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
        return {
            gt["q"]: q,
            gt["k"]: k,
            gt["v"]: v,
            gt["o"]: o,
            gt["g"]: g,
            gt["lse"]: lse,
            dq: dq_gpu,
            dk: dk_gpu,
            dv: dv_gpu,
        }

    graph, variant_pack, workspace = _build_cudnn_graph(
        q.dtype,
        {"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse},
        build,
    )

    def bwd_fn():
        graph.execute(variant_pack, workspace)
        return dq_gpu, dk_gpu, dv_gpu

    return bwd_fn
```
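The causal branch of `flops()` reduces to simple arithmetic: with bottom-right-aligned causal masking, rows attend to between `max(0, seqlen_k - seqlen_q)` and `seqlen_k` keys, so the average is the midpoint. A torch-free restatement of that branch (the `flops_causal` name is mine, used only so the check runs without a GPU):

```python
def flops_causal(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v):
    # Same arithmetic as the causal branch of flops() in bench_utils.py:
    # average keys attended per row under bottom-right-aligned causal masking.
    avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    # Factor of 2 per MAC; headdim for QK^T plus headdim_v for PV.
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


# Square causal attention does roughly half the work of dense attention:
print(flops_causal(1, 1, 128, 128, 64, 64))  # → 2097152.0
```

For a square 128x128 case this gives 2 * 128 * 64 * 128 = 2,097,152 FLOPS, exactly half the dense count of 4,194,304.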