Improve flash.cute paged_kv cpasync #2156
Conversation
```python
should_load = tXcX[0, m, 0][0] < seqlenk_row_limit
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
    row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
    should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
```
nit: `make_rmem_tensor_like`
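The diff above builds a per-row boolean predicate (`row_valid`) from the thread's coordinate tensor and the sequence-length limit. A minimal plain-Python model of that semantics is sketched below — it is not the cute DSL; the names `gmem_rows` and `row_limit` are illustrative stand-ins, and the zero-fill mimics what `cp.async` does when its predicate is false.

```python
def predicated_load(gmem_rows, row_limit):
    """Copy rows whose index is < row_limit; zero-fill the rest.

    Plain-Python model of a per-row predicated cp.async: a false
    predicate does not skip the store, it zero-fills the destination.
    """
    smem_rows = []
    for m, row in enumerate(gmem_rows):
        # mirrors: row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
        row_valid = m < row_limit
        smem_rows.append(list(row) if row_valid else [0] * len(row))
    return smem_rows
```

The key point is that out-of-bounds rows still produce well-defined (zeroed) shared memory, so no separate fill pass is needed.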
Thanks! I stand corrected — the combination of this PR and #2104 appears to almost close the gap between page size < 128 and page size = 128.
It's interesting that you use a predicate of size … and also set the gmem layout correctly via …, then use …. I can sometimes get better results, e.g. …. But this prefetch computation for the predicate sometimes hurts perf as well, e.g. for varlen. We can merge this PR and revisit once the distributed offset PR is also merged in.
Replace if/elif branching with predicated `cp.async` for paged KV loading. This simplifies the code by removing the `fill_swizzled` helper and using a single loop with the `pred=` parameter. Based on: Dao-AILab/flash-attention#2156
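The refactor described above collapses two code paths (explicit copy for in-bounds tiles, explicit fill for out-of-bounds ones) into one predicated loop. A hedged plain-Python sketch of the before/after shape — the real code issues `cute.copy(..., pred=...)` with `cp.async`; these helpers only model the control-flow difference:

```python
def load_branchy(src, valid):
    """Old shape: if/elif branching with a separate fill path."""
    dst = []
    for row, ok in zip(src, valid):
        if ok:
            dst.append(list(row))        # in-bounds: real copy
        else:
            dst.append([0] * len(row))   # out-of-bounds: dedicated fill helper
    return dst

def load_predicated(src, valid):
    """New shape: a single loop; the predicate selects copy vs zero-fill."""
    return [list(r) if ok else [0] * len(r) for r, ok in zip(src, valid)]
```

Both produce identical results; the predicated form removes the branch from the hot loop, which is what lets the hardware handle the masking.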
@jayhshah took a quick look & it seemed like the reduced predicate tensor increases spilling (which in general seems to be an issue in this code). So gonna merge as is.