Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from cutlass.cute.runtime import from_dlpack


def ceildiv(a: int, b: int) -> int:
    """Return ``a / b`` rounded up to the nearest integer.

    The result is computed with integer arithmetic only, by floor-dividing
    the negated numerator and negating the quotient. Unlike the
    ``(a + b - 1) // b`` form, this is a true ceiling for any non-zero
    divisor, including negative ones; for positive ``b`` both forms agree.

    Args:
        a: Numerator.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # floor(-a / b) == -ceil(a / b), so negating the floor yields the ceiling.
    return -(-a // b)


# Placeholder: an empty class named "Config" (no bases, no attributes), built
# dynamically with type(). In the visible code it is used only as the
# annotation for `config` parameters; the attributes read from those objects
# (seqlen_q, tile_m, ...) are not defined here.
# NOTE(review): presumably stands in for a real config type defined elsewhere — confirm.
Config = type("Config", (), {})

Expand Down Expand Up @@ -78,6 +82,26 @@ def _check_and_expand_block(
return expanded_cnt, expanded_idx


def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    compute_capability: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Compute the count and index shapes expected by block sparse normalization.

    Args:
        batch_size: Leading dimension of both shapes.
        num_head: Second dimension of both shapes.
        seqlen_q: Query sequence length, split into M blocks.
        seqlen_k: Key sequence length, split into N blocks.
        m_block_size: Tile size along the query dimension.
        n_block_size: Tile size along the key dimension.
        compute_capability: When equal to 10, the effective M block size is
            doubled before counting blocks.

    Returns:
        A pair ``(expected_count_shape, expected_index_shape)`` where the
        count shape is ``(batch_size, num_head, m_blocks)`` and the index
        shape is ``(batch_size, num_head, m_blocks, n_blocks)``.
    """
    # TODO: This multiplier should really be q_stage, wire up in later PR
    # 1 cta handles 2*tile_m rows on SM100
    if compute_capability == 10:
        rows_per_m_block = 2 * m_block_size
    else:
        rows_per_m_block = m_block_size

    # Partial trailing blocks still count, hence the ceiling division.
    num_m_blocks = ceildiv(seqlen_q, rows_per_m_block)
    num_n_blocks = ceildiv(seqlen_k, n_block_size)

    count_shape = (batch_size, num_head, num_m_blocks)
    # The index shape is the count shape with the N-block dimension appended.
    index_shape = (*count_shape, num_n_blocks)
    return count_shape, index_shape


def normalize_block_sparse_tensors(
tensors: BlockSparseTensorsTorch,
*,
Expand Down Expand Up @@ -205,8 +229,8 @@ def _compute_sparsity(
config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes block sparsity for fixed-length sequences."""
n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m
n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n
n_blocks_q = ceildiv(config.seqlen_q, config.tile_m)
n_blocks_k = ceildiv(config.seqlen_k, config.tile_n)

# Pre-allocate output tensors
full_block_cnt = torch.zeros(
Expand Down Expand Up @@ -325,12 +349,12 @@ def _compute_varlen_sparsity(
max_m_blocks = 0
for seq_idx in range(config.batch_size):
seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item()
n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m
n_blocks_q = ceildiv(seq_len_q, config.tile_m)
max_m_blocks = max(max_m_blocks, n_blocks_q)

# The number of K blocks is determined by the total length of all sequences.
total_k_len = cu_seqlens_k[-1].item()
max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n
max_n_blocks = ceildiv(total_k_len, config.tile_n)

# Pre-allocate padded output tensors
full_block_cnt = torch.zeros(
Expand Down Expand Up @@ -360,8 +384,8 @@ def _compute_varlen_sparsity(
seq_end_k = cu_seqlens_k[seq_idx + 1].item()
seq_len_k = seq_end_k - seq_start_k

n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m
n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n
n_blocks_q = ceildiv(seq_len_q, config.tile_m)
n_blocks_k = ceildiv(seq_len_k, config.tile_n)

# Global block indices are relative to the start of the entire batch tensor
first_m_block_global = seq_start_q // config.tile_m
Expand Down
Loading