
Improve flash.cute paged_kv cpasync #2156

Merged

v0i0 merged 2 commits into Dao-AILab:main from v0i0:v0i0/improve-paged-ldgsts
Jan 12, 2026
Conversation

@v0i0 (Collaborator) commented Jan 8, 2026

No description provided.

@v0i0 v0i0 requested review from drisspg and jayhshah January 8, 2026 22:24
```python
should_load = tXcX[0, m, 0][0] < seqlenk_row_limit
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
    row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
    should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
```
A collaborator left a review comment on this diff:

nit: make_rmem_tensor_like
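For readers outside the CuTe DSL, the per-row predicate fill in the diff can be sketched in plain Python. This is an analogy with made-up shapes and names, not the kernel code: the real loop builds an rmem Boolean fragment per row of the copy tile from a single `row_valid` bit.

```python
def make_row_predicates(num_rows, vals_per_row, seqlen_row_limit):
    """Build one boolean 'fragment' per row, analogous to filling a
    make_fragment_like(...) Boolean tensor from a single row_valid bit."""
    preds = []
    for m in range(num_rows):  # mirrors range_constexpr over mode 1 of tXsX
        # mirrors: row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
        row_valid = m < seqlen_row_limit
        # broadcast the single bit across every value the row's copy touches
        preds.append([row_valid] * vals_per_row)
    return preds

print(make_row_predicates(4, 2, 3))
# [[True, True], [True, True], [True, True], [False, False]]
```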

@drisspg (Collaborator) left a comment:

LGTM. CC @jayhshah, I'll let you take a look.

@reubenconducts (Contributor) commented Jan 9, 2026

Some very quick tests show ever-so-slightly worse perf than #2104. Have you measured that, i.e. adding your changes onto #2104?

@v0i0 (Collaborator, Author) commented Jan 9, 2026

> Some very quick tests show ever-so-slightly worse perf than #2104. Have you measured that, i.e. adding your changes onto #2104?

The first commit in the PR has the benchmark script I used.

@reubenconducts (Contributor) commented:

> Some very quick tests show ever-so-slightly worse perf than #2104. Have you measured that, i.e. adding your changes onto #2104?
>
> The first commit in the PR has the benchmark script I used.

Thanks! I stand corrected: the combination of this PR and #2104 appears to almost close the gap between page size < 128 and page size = 128:

```
COMBINED PRS
====================================================================================================
PAGED ATTENTION BENCHMARK
====================================================================================================
Page sizes: [1, 4, 8, 16, 32, 64, 128]
Head dimensions: [128]
Batch sizes: [4]
Sequence lengths: [65536]
Causal: True, dtype: torch.bfloat16
Testing fragmented page tables: False
====================================================================================================

### headdim=128, batch=4, seqlen=65536 ###
  Baseline (no paging): 0.653ms, 6.6 TFLOPS
  page_size=  1 (contiguous): 0.731ms, 5.8750 TFLOPS, overhead: +11.9%
  page_size=  4 (contiguous): 0.721ms, 5.9529 TFLOPS, overhead: +10.4%
  page_size=  8 (contiguous): 0.720ms, 5.9626 TFLOPS, overhead: +10.2%
  page_size= 16 (contiguous): 0.720ms, 5.9638 TFLOPS, overhead: +10.2%
  page_size= 32 (contiguous): 0.728ms, 5.8962 TFLOPS, overhead: +11.5%
  page_size= 64 (contiguous): 0.753ms, 5.7043 TFLOPS, overhead: +15.2%
  page_size=128 (contiguous): 0.656ms, 6.5462 TFLOPS, overhead: +0.4%
```

@jayhshah (Collaborator) commented Jan 9, 2026

It's interesting that you use a predicate of size (atom_v, rest_v) instead of (rest_v), but this still generates the right LDGSTS instructions. If I switch over to

```python
should_load = cute.make_fragment_like(tXsX[(0, None), None, 0], cute.Boolean)
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
    row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
    should_load[None, m].fill(row_valid)
```

and also set the gmem layout correctly via

```python
mX_paged_cur_copy_cur = mX_paged_cur_copy[None, ki]
tXsX_cur = tXsX[None, m, k]
mX_paged_cur_copy_cur = cute.make_tensor(mX_paged_cur_copy_cur.iterator, tXsX_cur.layout)
```

then use

```python
cute.copy(
    self.gmem_tiled_copy_KV,
    mX_paged_cur_copy_cur,
    tXsX_cur,
    pred=should_load[None, m],
)
```

I can sometimes get better results, e.g.

```
NOT CAUSAL
### headdim=128, batch=4, seqlen=8192 ###
  Baseline (no paging): 1.454ms, 1512.8 TFLOPS
  OLD page_size= 64 (contiguous): 3.819ms, 575.8 TFLOPS, overhead: +162.7%
  NEW page_size= 64 (contiguous): 3.418ms, 643.3 TFLOPS, overhead: +135.3%

CAUSAL
### headdim=128, batch=4, seqlen=8192 ###
  Baseline (no paging): 0.780ms, 1409.4 TFLOPS
  OLD page_size= 64 (contiguous): 1.834ms, 599.6 TFLOPS, overhead: +134.8%
  NEW page_size= 64 (contiguous): 1.793ms, 613.3 TFLOPS, overhead: +129.8%
```

But this prefetch computation for the predicate sometimes hurts perf as well, e.g. for varlen. We can merge this PR and revisit once the distributed offset PR is also merged in.
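The effect of passing `pred=` to `cute.copy` can be mimicked in plain Python. This is a sketch with made-up data, not the real mechanism: in the kernel, the predicate masks the generated cp.async/LDGSTS instructions per element rather than running a Python loop.

```python
def predicated_copy(src, dst, pred):
    """Copy src[i] into dst[i] only where pred[i] holds, leaving
    masked-out destination slots untouched (here: zero-initialized)."""
    for i, (value, keep) in enumerate(zip(src, pred)):
        if keep:
            dst[i] = value
    return dst

# Rows at or past the sequence-length limit are masked, as row_valid does.
seqlenk_row_limit = 3
src = [10, 11, 12, 13]  # stand-in K/V rows
pred = [m < seqlenk_row_limit for m in range(len(src))]
print(predicated_copy(src, [0] * len(src), pred))
# [10, 11, 12, 0]
```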

vikhyat added a commit to m87-labs/kestrel that referenced this pull request Jan 9, 2026:

> Replace if/elif branching with predicated cp.async for paged KV loading.
> This simplifies the code by removing the fill_swizzled helper and using
> a single loop with the pred= parameter.
>
> Based on: Dao-AILab/flash-attention#2156
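The refactor described in that commit message, replacing per-case branching with a single predicated copy, can be sketched outside the DSL. Names and data here are illustrative only; both versions produce the same result, which is why the branchy path can be dropped.

```python
def load_branchy(rows, limit):
    """Old style: explicit if/else per row choosing load vs zero-fill."""
    out = []
    for m, r in enumerate(rows):
        if m < limit:
            out.append(r)   # in-bounds: real load
        else:
            out.append(0)   # out-of-bounds: zero-fill
    return out

def load_predicated(rows, limit):
    """New style: one pass driven by a precomputed predicate,
    analogous to a single cute.copy call with pred=."""
    pred = [m < limit for m in range(len(rows))]
    return [r if keep else 0 for r, keep in zip(rows, pred)]

assert load_branchy([5, 6, 7, 8], 2) == load_predicated([5, 6, 7, 8], 2) == [5, 6, 0, 0]
```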
@v0i0 (Collaborator, Author) commented Jan 12, 2026

@jayhshah took a quick look and it seemed like the reduced predicate tensor increases spilling (which in general seems to be an issue in this code). So going to merge as-is.


@v0i0 v0i0 merged commit dbf08eb into Dao-AILab:main Jan 12, 2026
elewarr pushed a commit to elewarr/flash-attention that referenced this pull request Feb 4, 2026
YangWang92 pushed a commit to YangWang92/flash-attention that referenced this pull request Feb 15, 2026