Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 17 additions & 41 deletions flash_attn/cute/benchmark_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
create_cute_sliding_window_mask,
create_flex_sliding_window_mask,
)
from block_sparsity import compute_block_sparsity
from flash_attn.cute.block_sparsity import (
compute_block_sparsity,
BlockSparseTensorsTorch,
to_cute_block_sparse_tensors,
)


@dataclass
Expand Down Expand Up @@ -265,10 +269,12 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]:
)

if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]):
tensors["full_block_cnt"] = full_cnt.contiguous()
tensors["full_block_idx"] = full_idx.contiguous()
tensors["mask_block_cnt"] = mask_cnt.contiguous()
tensors["mask_block_idx"] = mask_idx.contiguous()
tensors["block_sparse_tensors"] = BlockSparseTensorsTorch(
mask_block_cnt=mask_cnt.contiguous(),
mask_block_idx=mask_idx.contiguous(),
full_block_cnt=full_cnt.contiguous(),
full_block_idx=full_idx.contiguous(),
)

if config.verbose:
total_full = full_cnt.sum().item()
Expand Down Expand Up @@ -373,33 +379,9 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
else None
)

# Block sparsity tensors
full_block_cnt_cute = (
from_dlpack(tensors["full_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=2
)
if "full_block_cnt" in tensors
else None
)
full_block_idx_cute = (
from_dlpack(tensors["full_block_idx"].detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=3
)
if "full_block_idx" in tensors
else None
)
mask_block_cnt_cute = (
from_dlpack(tensors["mask_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=2
)
if "mask_block_cnt" in tensors
else None
)
mask_block_idx_cute = (
from_dlpack(tensors["mask_block_idx"].detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=3
)
if "mask_block_idx" in tensors
blocksparse_tensors_cute = (
to_cute_block_sparse_tensors(tensors["block_sparse_tensors"])
if "block_sparse_tensors" in tensors
else None
)

Expand Down Expand Up @@ -436,11 +418,8 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
None, # page_table
window_left_cute,
window_right_cute,
learnable_sink_cute, # learnable_sink
full_block_cnt_cute,
full_block_idx_cute,
mask_block_cnt_cute,
mask_block_idx_cute,
learnable_sink_cute,
blocksparse_tensors_cute,
aux_tensors_cute,
# None,
)
Expand All @@ -461,10 +440,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
window_left_cute,
window_right_cute,
learnable_sink_cute,
full_block_cnt_cute,
full_block_idx_cute,
mask_block_cnt_cute,
mask_block_idx_cute,
blocksparse_tensors_cute,
aux_tensors_cute,
# None,
)
Expand Down
81 changes: 80 additions & 1 deletion flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,92 @@
by a more robust preprocessing kernel in the future.
"""

from typing import Tuple, Optional, Callable, List
from typing import Tuple, Optional, Callable, List, NamedTuple
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

# placeholder
Config = type("Config", (), {})


class BlockSparseTensors(NamedTuple):
    """Device-side (CUTE) block-sparsity descriptors consumed by the kernel.

    Field order matters: kernel code unpacks this tuple positionally as
    ``mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx``.
    Per the shapes documented at the kernel call sites, ``*_cnt`` tensors
    are (b, h, m_block) and ``*_idx`` tensors are (b, h, m_block, n_block).
    Mask tensors are mandatory; full-block tensors may be absent (None).
    """

    # Count of partially-masked blocks per query tile (b, h, m_block).
    mask_block_cnt: cute.Tensor
    # Indices of partially-masked blocks (b, h, m_block, n_block).
    mask_block_idx: cute.Tensor
    # Count of fully-unmasked blocks per query tile, or None when unused.
    full_block_cnt: Optional[cute.Tensor]
    # Indices of fully-unmasked blocks, or None when unused.
    full_block_idx: Optional[cute.Tensor]

    def __new_from_mlir_values__(self, values):
        # Hook used by the cutlass DSL to rebuild this tuple from MLIR
        # values during compilation; ignores `self` and returns a fresh
        # instance — presumably required by the DSL's value-materialization
        # protocol (TODO confirm against cutlass DSL docs).
        return BlockSparseTensors(*values)


class BlockSparseTensorsTorch(NamedTuple):
    """Host-side (torch) counterpart of :class:`BlockSparseTensors`.

    Holds the int32 CUDA tensors produced by block-sparsity preprocessing
    before they are converted to CUTE tensors via
    ``to_cute_block_sparse_tensors``. Mask tensors have no default and are
    expected; full-block tensors default to None when unused.
    """

    # Count of partially-masked blocks per query tile.
    mask_block_cnt: torch.Tensor
    # Indices of partially-masked blocks.
    mask_block_idx: torch.Tensor
    # Optional count of fully-unmasked blocks.
    full_block_cnt: Optional[torch.Tensor] = None
    # Optional indices of fully-unmasked blocks.
    full_block_idx: Optional[torch.Tensor] = None


def validate_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> None:
    """Check that a block-sparse tensor bundle is internally consistent.

    Raises:
        ValueError: if a cnt/idx pair is half-specified, if any tensor is
            not int32, if a pair lives on different devices, if any tensor
            is not on CUDA, or if the mask and full groups disagree on
            device.
    """
    pairs = (
        ("mask", tensors.mask_block_cnt, tensors.mask_block_idx),
        ("full", tensors.full_block_cnt, tensors.full_block_idx),
    )
    for kind, cnt, idx in pairs:
        have_cnt = cnt is not None
        have_idx = idx is not None
        if have_cnt != have_idx:
            raise ValueError(
                f"{kind}_block_cnt and {kind}_block_idx must both be provided or both be None"
            )
        if not have_cnt:
            continue
        if any(t.dtype != torch.int32 for t in (cnt, idx)):
            raise ValueError(f"{kind}_block tensors must have dtype torch.int32")
        if cnt.device != idx.device:
            raise ValueError(f"{kind}_block_cnt and {kind}_block_idx must be on the same device")
        if not (cnt.is_cuda and idx.is_cuda):
            raise ValueError(f"{kind}_block tensors must live on CUDA")

    full_cnt, mask_cnt = tensors.full_block_cnt, tensors.mask_block_cnt
    if full_cnt is not None and mask_cnt is not None and full_cnt.device != mask_cnt.device:
        raise ValueError("All block sparse tensors must be on the same device")


def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
    """Return True when at least one count tensor (mask or full) is set."""
    return not (tensors.full_block_cnt is None and tensors.mask_block_cnt is None)


def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]:
    """Convert host-side torch tensors into device-side CUTE tensors.

    Returns:
        None when block sparsity is disabled (no count tensors provided),
        otherwise a :class:`BlockSparseTensors` whose full-block entries
        are None when the corresponding torch tensors are absent.

    Raises:
        ValueError: if full-block tensors are provided without the
            mandatory mask tensors (``BlockSparseTensors`` types the mask
            fields as non-Optional).
    """
    if not is_block_sparsity_enabled(tensors):
        return None

    # Fail loudly on a half-specified bundle instead of crashing below
    # with `AttributeError: 'NoneType' object has no attribute 'detach'`.
    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
        raise ValueError(
            "mask_block_cnt and mask_block_idx are required when block sparsity is enabled"
        )

    def _wrap(t: torch.Tensor, leading_dim: int):
        # from_dlpack shares the tensor's storage (no copy). assumed_align=4
        # matches the int32 element size; leading_dim presumably marks which
        # mode of the dynamic layout is contiguous — confirm against the
        # cutlass from_dlpack docs.
        return from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=leading_dim
        )

    full_cnt = tensors.full_block_cnt
    full_idx = tensors.full_block_idx
    return BlockSparseTensors(
        mask_block_cnt=_wrap(tensors.mask_block_cnt, 2),
        mask_block_idx=_wrap(tensors.mask_block_idx, 3),
        full_block_cnt=_wrap(full_cnt, 2) if full_cnt is not None else None,
        full_block_idx=_wrap(full_idx, 3) if full_idx is not None else None,
    )


def compute_block_sparsity(
config: Config,
mask_mod_flex: Optional[Callable],
Expand Down
44 changes: 13 additions & 31 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute import pipeline
from flash_attn.cute.pack_gqa import PackGQA
from flash_attn.cute.named_barrier import NamedBarrierFwd
Expand Down Expand Up @@ -1271,10 +1272,7 @@ def __call__(
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
learnable_sink: Optional[cute.Tensor] = None,
full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block)
full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block)
mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block)
mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block)
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors: Optional[list] = None,
):
"""Configures and launches the flash attention kernel.
Expand All @@ -1290,6 +1288,7 @@ def __call__(
)
)


# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
Expand Down Expand Up @@ -1325,9 +1324,8 @@ def __call__(
)
# self.num_mma_regs = 232
# self.num_producer_regs = 40
self.use_block_sparsity = const_expr(
mask_block_cnt is not None and full_block_cnt is not None
)
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

self.use_scheduler_barrier = (
(self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128)
if const_expr(self.intra_wg_overlap)
Expand Down Expand Up @@ -1521,10 +1519,7 @@ def __call__(
window_size_left,
window_size_right,
learnable_sink,
full_block_cnt,
full_block_idx,
mask_block_cnt,
mask_block_idx,
blocksparse_tensors,
self.sQ_layout,
self.sK_layout,
self.sV_layout,
Expand Down Expand Up @@ -1571,10 +1566,7 @@ def kernel(
window_size_left: Optional[Int32],
window_size_right: Optional[Int32],
learnable_sink: Optional[cute.Tensor],
full_block_cnt: Optional[cute.Tensor],
full_block_idx: Optional[cute.Tensor],
mask_block_cnt: Optional[cute.Tensor],
mask_block_idx: Optional[cute.Tensor],
blocksparse_tensors: Optional[BlockSparseTensors],
sQ_layout: cute.ComposedLayout,
sK_layout: cute.ComposedLayout,
sV_layout: cute.ComposedLayout,
Expand Down Expand Up @@ -1698,10 +1690,7 @@ def kernel(
pipeline_k,
pipeline_v,
mbar_ptr_Q,
full_block_cnt,
full_block_idx,
mask_block_cnt,
mask_block_idx,
blocksparse_tensors,
block_info,
SeqlenInfoCls,
TileSchedulerCls,
Expand Down Expand Up @@ -1740,10 +1729,7 @@ def kernel(
SeqlenInfoCls,
AttentionMaskCls,
TileSchedulerCls,
full_block_cnt,
full_block_idx,
mask_block_cnt,
mask_block_idx,
blocksparse_tensors,
aux_tensors,
fastdiv_mods,
)
Expand All @@ -1763,10 +1749,7 @@ def load(
pipeline_k: cutlass.pipeline.PipelineAsync,
pipeline_v: cutlass.pipeline.PipelineAsync,
mbar_ptr_Q: cutlass.Pointer,
full_block_cnt: Optional[cute.Tensor],
full_block_idx: Optional[cute.Tensor],
mask_block_cnt: Optional[cute.Tensor],
mask_block_idx: Optional[cute.Tensor],
blocksparse_tensors: Optional[BlockSparseTensors],
block_info: BlockInfo,
SeqlenInfoCls: Callable,
TileSchedulerCls: Callable,
Expand Down Expand Up @@ -1852,6 +1835,7 @@ def load(
# ==========================================
# Flex Attention blocksparsity
# ==========================================
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
Expand Down Expand Up @@ -2033,10 +2017,7 @@ def mma(
SeqlenInfoCls: Callable,
AttentionMaskCls: Callable,
TileSchedulerCls: Callable,
full_block_cnt: Optional[cute.Tensor],
full_block_idx: Optional[cute.Tensor],
mask_block_cnt: Optional[cute.Tensor],
mask_block_idx: Optional[cute.Tensor],
blocksparse_tensors: Optional[BlockSparseTensors],
aux_tensors: Optional[list],
fastdiv_mods=None,
):
Expand Down Expand Up @@ -2263,6 +2244,7 @@ def mma(
# ==========================================
# Block sparsity
# ==========================================
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
Expand Down
7 changes: 2 additions & 5 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.pack_gqa import PackGQA
from flash_attn.cute import mma_sm100_desc as sm100_desc
from flash_attn.cute import blackwell_helpers as sm100_utils
Expand Down Expand Up @@ -223,10 +224,7 @@ def __call__(
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
learnable_sink: Optional[cute.Tensor] = None,
full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block)
full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block)
mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block)
mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block)
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors: Optional[list] = None,
):
"""Execute the Fused Multi-Head Attention operation on the provided tensors.
Expand All @@ -242,7 +240,6 @@ def __call__(
5. Grid and work scheduling computation
6. Kernel launch with appropriate parameters
"""

# setup static attributes before smem/grid/tma computation
self.q_dtype = mQ.element_type
self.k_dtype = mK.element_type
Expand Down
Loading