diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 669a731159a..f01a6921ffd 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -300,26 +300,20 @@ def _flash_attn_fwd( seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective total_mblocks = batch_size * num_head_kv * num_m_blocks + num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + num_SMs = torch.cuda.get_device_properties(device).multi_processor_count if num_splits < 1: - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size - num_splits = num_splits_heuristic( - total_mblocks, - torch.cuda.get_device_properties(device).multi_processor_count, - num_n_blocks, - 128, - ) + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if compute_capability in [10, 11] and head_dim != head_dim_v and num_splits > 1: - n_block_size = 64 - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size - num_splits = num_splits_heuristic( - total_mblocks, - torch.cuda.get_device_properties(device).multi_processor_count, - num_n_blocks, - 128, - ) + if num_n_blocks >= 64: + n_block_size = 64 + num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + else: + num_splits = 1 is_split_kv = num_splits > 1 if is_split_kv: