From 201f2a11368596cd94aaebb5341da4f8c76d7bf3 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Tue, 12 May 2026 19:05:23 +0000 Subject: [PATCH 1/2] split out varlen batch search into utils --- flash_attn/cute/compute_block_sparsity.py | 13 ++----------- flash_attn/cute/utils.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index 69ef3c77619..aaf5208c8c0 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -18,7 +18,7 @@ get_aux_tensor_metadata, to_cute_aux_tensor, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar, batch_search from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -161,16 +161,7 @@ class SharedStorage: m_block, head_idx, batch_idx = cute.arch.block_idx() else: global_m_block, head_idx, _ = cute.arch.block_idx() - # Binary search over cu_total_m_blocks to find batch_idx - lo = Int32(0) - hi = batch_size - while lo < hi: - mid = (lo + hi) // 2 - if mCuTotalMBlocks[mid + 1] <= global_m_block: - lo = mid + 1 - else: - hi = mid - batch_idx = lo + batch_idx = batch_search(global_m_block, mCuTotalMBlocks) m_block = global_m_block - mCuTotalMBlocks[batch_idx] seqlen = SeqlenInfoCls(batch_idx) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index ffabd34e398..87045096edd 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -949,3 +949,20 @@ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: def ssa_to_scalar(val): """Could inline but nice for reflecting the above api""" return val[0] + + +@cute.jit +def batch_search(token: Int32, cu_seqlens: cute.Tensor) -> Int32: + """Binary search to determine batch from packed token""" + batch_size = cute.size(cu_seqlens) - 1 + lo = Int32(0) + hi = batch_size + + while lo < hi: + mid = (lo + hi) // 2 + if cu_seqlens[mid + 1] <= token: + lo = mid + 1 + else: + hi = mid + + return lo From 62f107cb95a26cc16ae444056491ec6265e3bc8b Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Thu, 14 May 2026 17:20:01 +0000 Subject: [PATCH 2/2] more descriptive name --- flash_attn/cute/compute_block_sparsity.py | 9 +++++++-- flash_attn/cute/utils.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index aaf5208c8c0..777d3613eb1 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -18,7 +18,12 @@ get_aux_tensor_metadata, to_cute_aux_tensor, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar, batch_search +from flash_attn.cute.utils import ( + hash_callable, + scalar_to_ssa, + ssa_to_scalar, + get_batch_from_cu_tensor, +) from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -161,7 +166,7 @@ class SharedStorage: m_block, head_idx, batch_idx = cute.arch.block_idx() else: global_m_block, head_idx, _ = cute.arch.block_idx() - batch_idx = batch_search(global_m_block, mCuTotalMBlocks) + batch_idx = get_batch_from_cu_tensor(global_m_block, mCuTotalMBlocks) m_block = global_m_block - mCuTotalMBlocks[batch_idx] seqlen = SeqlenInfoCls(batch_idx) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 87045096edd..c8398c9a78d 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -952,15 +952,15 @@ def ssa_to_scalar(val): @cute.jit -def batch_search(token: Int32, cu_seqlens: cute.Tensor) -> Int32: - """Binary search to determine batch from packed token""" - batch_size = cute.size(cu_seqlens) - 1 +def get_batch_from_cu_tensor(idx: Int32, cu_tensor: cute.Tensor) -> Int32: + """Binary search to determine batch from packed index in a cumulative tensor""" + batch_size = cute.size(cu_tensor) - 1 lo = Int32(0) hi = batch_size while lo < hi: mid = (lo + hi) // 2 - if cu_seqlens[mid + 1] <= token: + if cu_tensor[mid + 1] <= idx: lo = mid + 1 else: hi = mid