Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,20 @@ def consume_block_sparse_loads(
return kv_consumer_state, O_should_accumulate, processed_any


@cute.jit
def split_block_range(block_count, split_idx: Int32, num_splits: Int32):
    """Compute the half-open ``[begin, end)`` block-list range for one SplitKV partition.

    The list is divided into chunks of ``ceil_div(block_count, num_splits)``
    entries, and partition ``split_idx`` is given the chunk at that position.
    Both bounds are clamped to ``block_count``, so a partition whose chunk
    would start at or past the end of the list receives an empty range
    (``begin == end``).

    Args:
        block_count: Number of entries in the block list being partitioned.
        split_idx: Index of the partition whose range is requested.
        num_splits: Total number of partitions the list is divided into.

    Returns:
        A ``(block_begin, block_end)`` pair of indices into the block list.
    """
    chunk_len = cute.ceil_div(block_count, num_splits)
    # Clamp the start so partitions beyond the list collapse to an empty range.
    range_begin = cutlass.min(split_idx * chunk_len, block_count)
    # Clamp the end so the last non-empty partition may be shorter than a full chunk.
    range_end = cutlass.min(range_begin + chunk_len, block_count)
    return range_begin, range_end


@cute.jit
def load_block_list_sm100(
block_indices: cute.Tensor,
block_count,
block_begin,
block_end,
load_q_with_first: cutlass.Constexpr,
q_stage: cutlass.Constexpr,
kv_producer_state,
Expand All @@ -563,9 +573,10 @@ def load_block_list_sm100(
pipeline_kv,
):
"""SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count)."""
block_count = block_end - block_begin
if block_count > 0:
# First iteration: load Q alongside K if requested
n_block_first = block_indices[block_count - 1]
n_block_first = block_indices[block_end - 1]

if const_expr(load_q_with_first):
# SM100 loads Q0 and optionally Q1
Expand All @@ -582,7 +593,7 @@ def load_block_list_sm100(

# Remaining blocks
for offset in cutlass.range(1, block_count):
n_block = block_indices[block_count - 1 - offset]
n_block = block_indices[block_end - 1 - offset]
load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
kv_producer_state.advance()
load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
Expand All @@ -598,7 +609,9 @@ def produce_block_sparse_loads_sm100(
batch_idx,
head_idx,
m_block,
seqlen_info,
seqlen_info: SeqlenInfoQK,
split_idx: Int32,
num_splits: Int32,
kv_producer_state,
load_Q,
load_K,
Expand Down Expand Up @@ -633,16 +646,19 @@ def produce_block_sparse_loads_sm100(
seqlen_info,
)

mask_empty = curr_mask_block_cnt == 0
full_empty = curr_full_block_cnt == 0
mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits)
full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits)
mask_empty = mask_begin == mask_end
full_empty = full_begin == full_end

q_phase_flipped = False

if mask_empty:
# No masked blocks: process full list with Q loading
kv_producer_state = load_block_list_sm100(
curr_full_block_idx,
curr_full_block_cnt,
full_begin,
full_end,
load_q_with_first=True,
q_stage=q_stage,
kv_producer_state=kv_producer_state,
Expand All @@ -656,7 +672,8 @@ def produce_block_sparse_loads_sm100(
# Process masked blocks with Q loading
kv_producer_state = load_block_list_sm100(
curr_mask_block_idx,
curr_mask_block_cnt,
mask_begin,
mask_end,
load_q_with_first=True,
q_stage=q_stage,
kv_producer_state=kv_producer_state,
Expand All @@ -671,7 +688,8 @@ def produce_block_sparse_loads_sm100(
# Process full blocks without Q loading
kv_producer_state = load_block_list_sm100(
curr_full_block_idx,
curr_full_block_cnt,
full_begin,
full_end,
load_q_with_first=False,
q_stage=q_stage,
kv_producer_state=kv_producer_state,
Expand All @@ -693,26 +711,29 @@ def get_total_block_count(
batch_idx,
head_idx,
m_block,
split_idx: Int32,
num_splits: Int32,
qhead_per_kvhead: cutlass.Constexpr,
q_subtile_factor: cutlass.Constexpr,
seqlen_info: SeqlenInfoQK,
):
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
mask_block_cnt, _, full_block_cnt, *_ = blocksparse_tensors

if const_expr(len(mask_block_cnt.shape) == 2):
# varlen path: tensors are [num_heads, total_m_block]
curr_m = seqlen_info.m_block_offset + m_block_sparse
total = mask_block_cnt[head_idx, curr_m]
if const_expr(full_block_cnt is not None):
total += full_block_cnt[head_idx, curr_m]
else:
# non-varlen: tensors are [batch, num_heads, m_block]
total = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
if const_expr(full_block_cnt is not None):
total += full_block_cnt[batch_idx, head_idx, m_block_sparse]
(
curr_mask_block_cnt,
_,
curr_full_block_cnt,
_,
) = get_curr_blocksparse_tensors(
batch_idx,
head_idx,
m_block_sparse,
blocksparse_tensors,
seqlen_info,
)

return total
mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits)
full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits)
return mask_end - mask_begin + full_end - full_begin


@cute.jit
Expand Down Expand Up @@ -839,7 +860,9 @@ def softmax_block_sparse_sm100(
batch_idx,
head_idx,
m_block,
seqlen_info,
seqlen_info: SeqlenInfoQK,
split_idx: Int32,
num_splits: Int32,
softmax_step: Callable,
mask_fn: Callable,
mask_fn_none: Callable,
Expand Down Expand Up @@ -870,13 +893,17 @@ def softmax_block_sparse_sm100(
seqlen_info,
)

total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt
mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits)
full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits)
split_mask_block_cnt = mask_end - mask_begin
split_full_block_cnt = full_end - full_begin
total_block_cnt = split_mask_block_cnt + split_full_block_cnt

if total_block_cnt == 0:
sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
else:
if curr_mask_block_cnt > 0:
mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
if split_mask_block_cnt > 0:
mask_n_block = curr_mask_block_idx[mask_end - 1]
(
mma_si_consumer_phase,
si_corr_producer_phase,
Expand All @@ -889,8 +916,8 @@ def softmax_block_sparse_sm100(
is_first=True,
mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
)
for i in cutlass.range(1, curr_mask_block_cnt):
mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
for i in cutlass.range(1, split_mask_block_cnt):
mask_n_block = curr_mask_block_idx[mask_end - 1 - i]
(
mma_si_consumer_phase,
si_corr_producer_phase,
Expand All @@ -903,9 +930,9 @@ def softmax_block_sparse_sm100(
mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
)

if curr_full_block_cnt > 0:
full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
if curr_mask_block_cnt == 0:
if split_full_block_cnt > 0:
full_n_block = curr_full_block_idx[full_end - 1]
if split_mask_block_cnt == 0:
(
mma_si_consumer_phase,
si_corr_producer_phase,
Expand Down Expand Up @@ -935,8 +962,8 @@ def softmax_block_sparse_sm100(
mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
),
)
for i in cutlass.range(1, curr_full_block_cnt):
full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
for i in cutlass.range(1, split_full_block_cnt):
full_n_block = curr_full_block_idx[full_end - 1 - i]
(
mma_si_consumer_phase,
si_corr_producer_phase,
Expand Down
15 changes: 14 additions & 1 deletion flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,7 @@ def kernel(
num_splits,
SeqlenInfoCls,
mma_tile_coord_v,
blocksparse_tensors=blocksparse_tensors,
tile_scheduler=tile_scheduler,
)

Expand Down Expand Up @@ -1499,6 +1500,8 @@ def load(
head_idx,
m_block,
seqlen,
split_idx,
num_splits,
kv_producer_state,
load_Q,
load_K,
Expand Down Expand Up @@ -1646,6 +1649,8 @@ def mma(
batch_idx,
head_idx,
m_block,
split_idx,
num_splits,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
seqlen_info=seqlen,
Expand Down Expand Up @@ -2013,6 +2018,8 @@ def softmax_loop(
batch_idx,
head_idx,
m_block,
split_idx,
num_splits,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
seqlen_info=seqlen,
Expand Down Expand Up @@ -2072,6 +2079,8 @@ def softmax_loop(
head_idx,
m_block,
seqlen,
split_idx,
num_splits,
softmax_step,
mask_fn,
mask_fn_none,
Expand Down Expand Up @@ -2427,6 +2436,8 @@ def correction_loop(
batch_idx,
head_idx,
m_block,
split_idx,
num_splits,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
self.q_subtile_factor if self.q_subtile_factor is not None else 1,
seqlen_info=seqlen,
Expand Down Expand Up @@ -2841,6 +2852,7 @@ def epilogue_s2g(
num_splits: int,
SeqlenInfoCls: Callable,
mma_tile_coord_v: Int32 = 0,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
tile_scheduler=None,
):
epi_consumer_phase = Int32(0)
Expand All @@ -2849,8 +2861,9 @@ def epilogue_s2g(
m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
seqlen = SeqlenInfoCls(batch_idx)
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max

if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
if has_work:
if const_expr(self.is_split_kv):
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
else:
Expand Down
4 changes: 0 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,10 +604,6 @@ def _flash_attn_fwd(
head_dim_idx = 0 if block_sparse_tensors.mask_block_cnt.ndim == 2 else 1
if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[head_dim_idx] != 1:
pack_gqa = False
if is_split_kv:
raise NotImplementedError(
"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
)
if cu_seqlens_q is not None:
assert block_sparse_tensors.cu_total_m_blocks is not None, (
"Varlen block sparsity requires block_sparse_tensors.cu_total_m_blocks."
Expand Down
Loading