diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index 69ef3c77619..777d3613eb1 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -18,7 +18,12 @@ get_aux_tensor_metadata, to_cute_aux_tensor, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.utils import ( + hash_callable, + scalar_to_ssa, + ssa_to_scalar, + get_batch_from_cu_tensor, +) from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -161,16 +166,7 @@ class SharedStorage: m_block, head_idx, batch_idx = cute.arch.block_idx() else: global_m_block, head_idx, _ = cute.arch.block_idx() - # Binary search over cu_total_m_blocks to find batch_idx - lo = Int32(0) - hi = batch_size - while lo < hi: - mid = (lo + hi) // 2 - if mCuTotalMBlocks[mid + 1] <= global_m_block: - lo = mid + 1 - else: - hi = mid - batch_idx = lo + batch_idx = get_batch_from_cu_tensor(global_m_block, mCuTotalMBlocks) m_block = global_m_block - mCuTotalMBlocks[batch_idx] seqlen = SeqlenInfoCls(batch_idx) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index ffabd34e398..c8398c9a78d 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -949,3 +949,20 @@ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: def ssa_to_scalar(val): """Could inline but nice for reflecting the above api""" return val[0] + + +@cute.jit +def get_batch_from_cu_tensor(idx: Int32, cu_tensor: cute.Tensor) -> Int32: + """Binary search to determine batch from packed index in a cumulative tensor""" + batch_size = cute.size(cu_tensor) - 1 + lo = Int32(0) + hi = batch_size + + while lo < hi: + mid = (lo + hi) // 2 + if cu_tensor[mid + 1] <= idx: + lo = mid + 1 + else: + hi = mid + + return lo