flash_attn/cute/interface.py (24 changes: 9 additions & 15 deletions)

@@ -300,26 +300,20 @@ def _flash_attn_fwd(
     seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
     num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
     total_mblocks = batch_size * num_head_kv * num_m_blocks
+    num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
+    num_SMs = torch.cuda.get_device_properties(device).multi_processor_count
     if num_splits < 1:
-        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
-        num_splits = num_splits_heuristic(
-            total_mblocks,
-            torch.cuda.get_device_properties(device).multi_processor_count,
-            num_n_blocks,
-            128,
-        )
+        num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
 
     # SplitKV uses float32 partial output, which doubles the O buffer size
     # in shared memory, causing OOM for diff-headdim (192, 128)
     if compute_capability in [10, 11] and head_dim != head_dim_v and num_splits > 1:
-        n_block_size = 64
-        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
-        num_splits = num_splits_heuristic(
-            total_mblocks,
-            torch.cuda.get_device_properties(device).multi_processor_count,
-            num_n_blocks,
-            128,
-        )
+        if num_n_blocks >= 64:
+            n_block_size = 64
+            num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
+            num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
+        else:
+            num_splits = 1
 
     is_split_kv = num_splits > 1
     if is_split_kv:
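
For readers skimming this hunk: `num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)` picks how many chunks to split the KV sequence into so that the launch grid keeps the GPU's SMs busy. Below is a minimal sketch of a heuristic of this shape, assuming it trades wave occupancy against the cost of combining float32 partials; it is illustrative only, not flash_attn's actual implementation.

```python
import math

def num_splits_sketch(total_mblocks: int, num_SMs: int,
                      num_n_blocks: int, max_splits: int) -> int:
    """Pick a KV split count that keeps the SMs busy (illustrative sketch)."""
    if total_mblocks == 0 or total_mblocks >= 0.8 * num_SMs:
        return 1  # grid already (nearly) fills the device, or nothing to do
    # Can't split finer than the number of KV blocks or the number of SMs.
    max_splits = max(1, min(max_splits, num_SMs, num_n_blocks))

    def efficiency(n_splits: int) -> float:
        # Occupancy proxy: how full the last wave of thread blocks is.
        n_waves = total_mblocks * n_splits / num_SMs
        return n_waves / math.ceil(n_waves)

    best = max(efficiency(s) for s in range(1, max_splits + 1))
    # Prefer the smallest split count within 85% of the best efficiency,
    # since more splits mean more float32 partial results to combine.
    for n_splits in range(1, max_splits + 1):
        if efficiency(n_splits) >= 0.85 * best:
            return n_splits
    return 1
```

For example, `num_splits_sketch(8, 132, 40, 128)` returns 15 on a hypothetical 132-SM device: 8 m-blocks alone would idle most SMs, while splitting KV 15 ways yields 120 thread blocks and fills roughly 90% of the device in one wave.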
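
On the shared-memory comment in the hunk: split-KV accumulates partial outputs in float32 rather than the 16-bit output dtype, so the O staging buffer doubles. A back-of-envelope with assumed tile shapes (the actual kernel's smem layout may differ):

```python
# Back-of-envelope smem sizes; tile shapes are assumptions, not the kernel's layout.
m_block, n_block = 128, 128
head_dim, head_dim_v = 192, 128   # the diff-headdim case the comment mentions
bytes_bf16, bytes_f32 = 2, 4

o_bf16 = m_block * head_dim_v * bytes_bf16  # 32 KiB: normal 16-bit O tile
o_f32  = m_block * head_dim_v * bytes_f32   # 64 KiB: split-KV float32 partial O
k_tile = n_block * head_dim * bytes_bf16    # 48 KiB per K pipeline stage
v_tile = n_block * head_dim_v * bytes_bf16  # 32 KiB per V pipeline stage
print(o_f32 - o_bf16)                       # 32768: extra bytes split-KV needs
```

Under these assumptions the float32 partial O tile costs an extra 32 KiB on top of multi-stage K/V tiles, which is tight against the per-SM shared-memory budget when head_dim is 192. Halving n_block_size to 64 halves the per-stage K and V tiles to make room, which is why the fallback shrinks the tile rather than disabling split-KV outright; the new `else` branch instead falls back to no splitting when there are fewer than 64 KV blocks.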