Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions flash_attn/cute/block_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,20 @@ def get_n_block_min_before_local_mask(
n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
n_idx_left = n_idx - self.window_size_left
return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))

@cute.jit
def get_n_block_max_for_m_block(
    self,
    seqlen_info: SeqlenInfoQK,
    m_block: Int32,
    n_block_global_max: Int32,
) -> Int32:
    """Return the upper n-block bound for ``m_block``, capped at ``n_block_global_max``.

    When the instance is causal or has a right window
    (``self.window_size_right is not None``), the bound is derived from the
    end of the ``m_block`` tile and clamped so it never exceeds
    ``n_block_global_max``. Otherwise ``n_block_global_max`` is returned
    unchanged.

    Args:
        seqlen_info: Supplies ``seqlen_q`` and ``seqlen_k``.
        m_block: Index of the M tile.
        n_block_global_max: Bound to cap the result at.

    Returns:
        ``min(n_block_global_max, ceil_div(n_idx_right, self.tile_n))`` in the
        causal / right-window case, else ``n_block_global_max``.
    """
    # The branch condition is wrapped in const_expr, as are the nested ones below.
    if const_expr(self.is_causal or self.window_size_right is not None):
        # One past the last M index covered by this tile.
        m_idx_max = (m_block + 1) * self.tile_m
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            # NOTE(review): presumably converts a packed (row, q-head) index back
            # to a sequence row when pack-GQA is active — confirm with the packing code.
            m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
        # Offset by the length difference between K and Q.
        # NOTE(review): looks like bottom-right alignment of the mask — confirm.
        n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        if const_expr(self.window_size_right is not None):
            # Widen the bound by the right window size.
            n_idx_right += self.window_size_right
        # Convert the N index to a tile count (rounded up) and cap it.
        return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n))
    return n_block_global_max
22 changes: 10 additions & 12 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def produce_block_sparse_loads(
must be converted to unpacked for sparse tensor indexing.
"""

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors

m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

Expand Down Expand Up @@ -332,7 +332,7 @@ def consume_block_sparse_loads(
must be converted to unpacked for sparse tensor indexing.
"""

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors

m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

Expand Down Expand Up @@ -552,7 +552,7 @@ def produce_block_sparse_loads_sm100(
"""
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
Expand Down Expand Up @@ -629,7 +629,7 @@ def get_total_block_count(
):
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors
if const_expr(full_block_cnt is not None):
return (
mask_block_cnt[batch_idx, head_idx, m_block_sparse]
Expand Down Expand Up @@ -780,7 +780,7 @@ def softmax_block_sparse_sm100(
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
Expand All @@ -795,8 +795,6 @@ def softmax_block_sparse_sm100(
total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

if total_block_cnt == 0:
# See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
# pipeline_sm_stats.producer_commit_w_index(stage_idx)
sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
else:
if curr_mask_block_cnt > 0:
Expand Down Expand Up @@ -907,7 +905,7 @@ def get_total_q_block_count_bwd(
m_block_max: int = 0,
):
"""Count total tile iterations for given n_block (KV tile) in backward."""
q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
q_block_cnt, _, full_block_cnt, _, *_ = blocksparse_tensors
total = q_block_cnt[batch_idx, head_idx, n_block]
if const_expr(full_block_cnt is not None):
total = total + full_block_cnt[batch_idx, head_idx, n_block]
Expand Down Expand Up @@ -1051,7 +1049,7 @@ def get_block_sparse_iteration_info_bwd(

Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
"""
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

Expand Down Expand Up @@ -1175,7 +1173,7 @@ def produce_block_sparse_q_loads_bwd_sm90(

Returns updated (producer_state_Q, producer_state_dO).
"""
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

Expand Down Expand Up @@ -1270,7 +1268,7 @@ def consume_block_sparse_mma_bwd_sm90(

Returns updated (consumer_state_Q, consumer_state_dO).
"""
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

Expand Down Expand Up @@ -1396,7 +1394,7 @@ def dQaccum_store_block_sparse_bwd_sm90(

Iterates partial blocks first, then full blocks, matching producer/consumer order.
"""
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

Expand Down
213 changes: 194 additions & 19 deletions flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ class BlockSparseTensors(NamedTuple):
mask_block_idx: cute.Tensor
full_block_cnt: cute.Tensor | None
full_block_idx: cute.Tensor | None
dq_write_order: cute.Tensor | None = None
dq_write_order_full: cute.Tensor | None = None

def __new_from_mlir_values__(self, values):
    """Rebuild a ``BlockSparseTensors`` from a flat sequence of values.

    Shorter sequences are right-padded with ``None`` so the trailing fields
    are filled in: a sequence of length 2 receives four ``None`` entries and a
    sequence of length 4 receives two. Any other length is forwarded as is.
    """
    # Number of trailing ``None`` entries to append, keyed by the input length.
    missing = {2: 4, 4: 2}.get(len(values), 0)
    padding = (None,) * missing
    return BlockSparseTensors(*values, *padding)

Expand All @@ -32,6 +36,138 @@ class BlockSparseTensorsTorch(NamedTuple):
full_block_cnt: torch.Tensor | None = None
full_block_idx: torch.Tensor | None = None
block_size: tuple[int, int] | None = None
dq_write_order: torch.Tensor | None = None
dq_write_order_full: torch.Tensor | None = None
spt: bool | None = None


def _ordered_to_dense_simple(

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto here

num_blocks: torch.Tensor,
indices: torch.Tensor,
num_cols: int,
) -> torch.Tensor:
"""Convert ordered sparse representation to dense binary matrix.

Args:
num_blocks: [B, H, num_rows] count of valid entries per row
indices: [B, H, num_rows, max_entries] column indices (valid entries packed left)
num_cols: total number of columns

Returns:
dense: [B, H, num_rows, num_cols] binary int32 matrix
"""
B, H, num_rows, max_entries = indices.shape
device = indices.device
dense = torch.zeros(B, H, num_rows, num_cols + 1, dtype=torch.int32, device=device)
col_range = torch.arange(max_entries, device=device)
valid = col_range[None, None, None, :] < num_blocks[:, :, :, None]
safe_indices = torch.where(valid, indices.long(), num_cols)
row_idx = torch.arange(num_rows, device=device)[None, None, :, None].expand_as(indices)
b_idx = torch.arange(B, device=device)[:, None, None, None].expand_as(indices)
h_idx = torch.arange(H, device=device)[None, :, None, None].expand_as(indices)
dense[b_idx, h_idx, row_idx, safe_indices] = 1
return dense[:, :, :, :num_cols]


def compute_dq_write_order(
    fwd_mask_cnt: torch.Tensor,
    fwd_mask_idx: torch.Tensor,
    fwd_full_cnt: torch.Tensor | None,
    fwd_full_idx: torch.Tensor | None,
    bwd_mask_cnt: torch.Tensor,
    bwd_mask_idx: torch.Tensor,
    bwd_full_cnt: torch.Tensor | None,
    bwd_full_idx: torch.Tensor | None,
    spt: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Compute dQ write-order metadata for deterministic block-sparse backward.

    For each (n_block, i) in the backward iteration, computes the semaphore
    lock value: the rank of n_block in the combined (partial + full) sorted
    contributor list for the target m_block.

    Lock values are assigned in ascending n_block order (or descending if spt=True)
    to guarantee deadlock-freedom with the CTA scheduling order.

    Args:
        fwd_mask_cnt: [B, H, num_m_blocks] partial contributor counts per m_block
        fwd_mask_idx: [B, H, num_m_blocks, max_kv] partial contributor n_block indices (ascending)
        fwd_full_cnt: [B, H, num_m_blocks] full contributor counts per m_block (optional)
        fwd_full_idx: [B, H, num_m_blocks, max_kv] full contributor n_block indices (optional)
        bwd_mask_cnt: [B, H, num_n_blocks] partial iteration counts per n_block
            (accepted for interface symmetry; not read by this function)
        bwd_mask_idx: [B, H, num_n_blocks, max_q] partial iteration m_block indices
        bwd_full_cnt: [B, H, num_n_blocks] full iteration counts per n_block (optional;
            only checked for presence)
        bwd_full_idx: [B, H, num_n_blocks, max_q] full iteration m_block indices (optional)
        spt: if True, reverse ordering (highest n_block gets lock_value=0)

    Returns:
        (dq_write_order, dq_write_order_full): tensors parallel to bwd_mask_idx
        and bwd_full_idx respectively, containing lock values.
        dq_write_order_full is None unless all four full tensors are provided.

    Note:
        The backward counts are not used to mask padding slots. Every entry of
        the backward index tensors is clamped into [0, num_m_blocks - 1] and
        looked up, so slots past a row's count still carry a value that
        consumers are expected to ignore.
    """
    # NOTE(review): a PR comment on this function says it "can move ... to PT call
    # after we land there" — presumably the PyTorch-side call; confirm before relying on it.
    device = fwd_mask_idx.device
    B, H, num_m, _ = fwd_mask_idx.shape
    _, _, num_n, _ = bwd_mask_idx.shape

    has_full = fwd_full_cnt is not None and fwd_full_idx is not None

    # dense[b, h, m, n] == 1 iff n_block n contributes (partially or fully) to m_block m.
    dense = _ordered_to_dense_simple(fwd_mask_cnt, fwd_mask_idx, num_n)
    if has_full:
        dense_full = _ordered_to_dense_simple(fwd_full_cnt, fwd_full_idx, num_n)
        dense = (dense + dense_full).clamp(max=1)

    # Exclusive prefix sum along n: the number of contributors with a smaller
    # n_block index, i.e. the ascending rank of each contributor.
    cumsum = dense.cumsum(dim=-1)
    rank_table = (cumsum - dense).to(torch.int32)

    if spt:
        # Reverse the ranks so the highest contributing n_block gets 0.
        total_per_m = cumsum[:, :, :, -1:]
        rank_table = (total_per_m - 1 - rank_table).to(torch.int32)

    # Batch/head index ranges are shared by both lookups; they stay at
    # broadcastable shapes and are expanded by advanced indexing.
    b_i = torch.arange(B, device=device)[:, None, None, None]
    h_i = torch.arange(H, device=device)[None, :, None, None]

    def _gather_write_order(bwd_idx: torch.Tensor) -> torch.Tensor:
        """Look up rank_table[b, h, bwd_idx[b, h, n, i], n] for every slot."""
        n_i = torch.arange(bwd_idx.shape[2], device=device)[None, None, :, None]
        # Clamp so padding slots cannot index out of range.
        m_vals = bwd_idx.long().clamp(0, num_m - 1)
        return rank_table[b_i, h_i, m_vals, n_i]

    dq_write_order = _gather_write_order(bwd_mask_idx)

    dq_write_order_full = None
    if has_full and bwd_full_cnt is not None and bwd_full_idx is not None:
        dq_write_order_full = _gather_write_order(bwd_full_idx)

    return dq_write_order, dq_write_order_full


def compute_dq_write_order_from_block_mask(
    block_mask,
    spt: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Compute dQ write-order metadata from a block-mask object.

    ``block_mask.as_tuple()`` must yield at least ten leading entries. The
    first two are ignored. The next eight are forwarded to
    ``compute_dq_write_order`` in order: the KV-side partial count/index and
    full count/index (used as the forward tensors), then the Q-side partial
    count/index and full count/index (used as the backward tensors). Any
    further entries are ignored.

    Args:
        block_mask: Object exposing ``as_tuple()`` with the layout above.
        spt: Forwarded unchanged to ``compute_dq_write_order``.

    Returns:
        The ``(dq_write_order, dq_write_order_full)`` pair produced by
        ``compute_dq_write_order``.
    """
    fields = block_mask.as_tuple()
    (
        _,
        _,
        kv_cnt,
        kv_idx,
        kv_full_cnt,
        kv_full_idx,
        q_cnt,
        q_idx,
        q_full_cnt,
        q_full_idx,
        *_,
    ) = fields
    return compute_dq_write_order(
        fwd_mask_cnt=kv_cnt,
        fwd_mask_idx=kv_idx,
        fwd_full_cnt=kv_full_cnt,
        fwd_full_idx=kv_full_idx,
        bwd_mask_cnt=q_cnt,
        bwd_mask_idx=q_idx,
        bwd_full_cnt=q_full_cnt,
        bwd_full_idx=q_full_idx,
        spt=spt,
    )


def get_sparse_q_block_size(
Expand Down Expand Up @@ -110,6 +246,25 @@ def _check_and_expand_block(
return expanded_cnt, expanded_idx


def _check_and_expand_metadata_tensor(
name: str,
tensor: torch.Tensor | None,
expected_shape: Tuple[int, ...],
context: str | None,
hint: str | Callable[[], str] | None,
device: torch.device,
) -> torch.Tensor | None:
if tensor is None:
return None
if tensor.dtype != torch.int32:
raise ValueError(f"{name} must have dtype torch.int32")
if tensor.device != device:
raise ValueError(f"{name} must be on the same device as block sparse tensors")
if not tensor.is_cuda:
raise ValueError(f"{name} must live on CUDA")
return _expand_sparsity_tensor(tensor, expected_shape, name, context, hint)


def get_block_sparse_expected_shapes(
batch_size: int,
num_head: int,
Expand Down Expand Up @@ -279,12 +434,37 @@ def normalize_block_sparse_tensors(
if full_cnt is not None and mask_cnt.device != full_cnt.device:
raise ValueError("All block sparse tensors must be on the same device")

dq_write_order = _check_and_expand_metadata_tensor(
"dq_write_order",
tensors.dq_write_order,
tuple(mask_idx.shape),
context,
hint,
mask_cnt.device,
)
dq_write_order_full = _check_and_expand_metadata_tensor(
"dq_write_order_full",
tensors.dq_write_order_full,
tuple(full_idx.shape) if full_idx is not None else expected_index_shape,
context,
hint,
mask_cnt.device,
)
spt = tensors.spt
if spt is not None and not isinstance(spt, bool):
raise ValueError("spt must be a bool when provided")
if spt is not None and dq_write_order is None:
raise ValueError("spt requires dq_write_order to be provided")

return BlockSparseTensorsTorch(
mask_block_cnt=mask_cnt,
mask_block_idx=mask_idx,
full_block_cnt=full_cnt,
full_block_idx=full_idx,
block_size=tensors.block_size,
dq_write_order=dq_write_order,
dq_write_order_full=dq_write_order_full,
spt=spt,
)


Expand Down Expand Up @@ -316,6 +496,8 @@ def get_block_sparse_broadcast_pattern(
tensors.mask_block_idx,
tensors.full_block_cnt,
tensors.full_block_idx,
tensors.dq_write_order,
tensors.dq_write_order_full,
):
if tensor is not None:
patterns.append(get_broadcast_dims(tensor))
Expand Down Expand Up @@ -423,37 +605,30 @@ def to_cute_block_sparse_tensors(
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
if not is_block_sparsity_enabled(tensors):
return None

(
mask_block_cnt,
mask_block_idx,
full_block_cnt,
full_block_idx,
*_,
) = tensors

(
mask_block_cnt_tensor,
mask_block_idx_tensor,
) = [
mask_block_cnt_tensor, mask_block_idx_tensor = [
to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
for t in (mask_block_cnt, mask_block_idx)
for t in (tensors.mask_block_cnt, tensors.mask_block_idx)
]
(
full_block_cnt_tensor,
full_block_idx_tensor,
) = [
full_block_cnt_tensor, full_block_idx_tensor = [
to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
if t is not None
else None
for t in (tensors.full_block_cnt, tensors.full_block_idx)
]
dq_write_order_tensor, dq_write_order_full_tensor = [
to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
if t is not None
else None
for t in (full_block_cnt, full_block_idx)
for t in (tensors.dq_write_order, tensors.dq_write_order_full)
]

return BlockSparseTensors(
mask_block_cnt_tensor,
mask_block_idx_tensor,
full_block_cnt_tensor,
full_block_idx_tensor,
dq_write_order_tensor,
dq_write_order_full_tensor,
)


Expand Down
Loading