Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
714 changes: 714 additions & 0 deletions flash_attn/cute/benchmark_mask_mod.py

Large diffs are not rendered by default.

372 changes: 372 additions & 0 deletions flash_attn/cute/block_sparsity.py

Large diffs are not rendered by default.

655 changes: 534 additions & 121 deletions flash_attn/cute/flash_fwd.py

Large diffs are not rendered by default.

94 changes: 85 additions & 9 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.

# Supported features:
# - BF16 & FP16 dtype
Expand Down Expand Up @@ -73,7 +74,12 @@ def _flash_attn_fwd(
num_threads: int = 384,
pack_gqa: Optional[bool] = None,
_compute_capability: Optional[int] = None,
score_mod: Callable | None = None,
score_mod: Optional[Callable] = None,
mask_mod: Optional[Callable] = None,
full_block_cnt: Optional[torch.Tensor] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some postfacto comments feel free to ignore:

can we maybe put full and mask cnts/indices ina a tuple so that its easier to pass around

full_block_idx: Optional[torch.Tensor] = None,
mask_block_cnt: Optional[torch.Tensor] = None,
mask_block_idx: Optional[torch.Tensor] = None,
return_lse: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -135,7 +141,22 @@ def _flash_attn_fwd(
if learnable_sink is not None:
assert learnable_sink.shape == (num_head,)
assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device"
for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]:
if t is not None:
assert t.dtype == torch.int32, "blocksparse mask tensors must be int32"
assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous"
assert all(
t is None or t.is_cuda
for t in (
q, k, v,
cu_seqlens_q, cu_seqlens_k,
seqused_q, seqused_k,
page_table,
learnable_sink,
full_block_cnt, full_block_idx,
mask_block_cnt, mask_block_idx,
)
), "inputs must be on CUDA device"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
Expand Down Expand Up @@ -183,6 +204,13 @@ def _flash_attn_fwd(
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
]
page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None

full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None
full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None
mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None
mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None


if causal:
window_size_right = 0
local = window_size_left is not None or window_size_right is not None
Expand All @@ -202,22 +230,44 @@ def _flash_attn_fwd(
# TODO: fix the varlen case
if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None):
pack_gqa = False


# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None

if softcap is not None:
assert score_mod is None, "softcap and score_mod cannot be used together"
score_mod = utils.create_softcap_scoremod(softcap)

is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None
use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None
if score_mod is not None:
is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None
if is_varlen:
raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.")
if pack_gqa:
raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a rebase bug, I added support here: #1937


if mask_mod is not None:
if not use_block_sparsity:
raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.")
if is_varlen:
raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.")
if pack_gqa:
raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.")

if use_block_sparsity:
if is_varlen:
raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.")
if pack_gqa:
raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.")

cute_buffers = None
if buffers is not None:
cute_buffers = [from_dlpack(buf) for buf in buffers]

compile_key = (
dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, utils.hash_callable(score_mod) if score_mod is not None else None,
dtype, head_dim, head_dim_v, qhead_per_kvhead, causal,
score_mod_hash, mask_mod_hash,
buffers is not None,
lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
page_table is not None,
Expand Down Expand Up @@ -245,6 +295,9 @@ def _flash_attn_fwd(
num_stages=2,
num_threads=num_threads,
Q_in_regs=False,
intra_wg_overlap=True,
mma_pv_is_rs=True,
mask_mod=mask_mod,
score_mod=score_mod,
has_buffers=buffers is not None,
)
Expand All @@ -264,18 +317,21 @@ def _flash_attn_fwd(
else:
raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x")
# TODO: check @can_implement
# TODO caching for buffers; cute_buffers
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(
fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
page_table_tensor,
window_size_left, window_size_right, learnable_sink_tensor, cute_buffers,
window_size_left, window_size_right, learnable_sink_tensor,
full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor,
cute_buffers,
)
_flash_attn_fwd.compile_cache[compile_key](
q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
page_table_tensor,
window_size_left, window_size_right, learnable_sink_tensor, cute_buffers
window_size_left, window_size_right, learnable_sink_tensor,
full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor,
cute_buffers,
)
return out, lse

Expand Down Expand Up @@ -591,6 +647,11 @@ def forward(
learnable_sink: Optional[torch.Tensor] = None,
softcap: float = 0.0,
pack_gqa: Optional[bool] = None,
mask_mod: Optional[Callable] = None,
full_block_cnt: Optional[torch.Tensor] = None,
full_block_idx: Optional[torch.Tensor] = None,
mask_block_cnt: Optional[torch.Tensor] = None,
mask_block_idx: Optional[torch.Tensor] = None,
):
out, lse = _flash_attn_fwd(
q,
Expand All @@ -603,6 +664,11 @@ def forward(
learnable_sink=learnable_sink,
softcap=softcap,
pack_gqa=pack_gqa,
mask_mod=mask_mod,
full_block_cnt=full_block_cnt,
full_block_idx=full_block_idx,
mask_block_cnt=mask_block_cnt,
mask_block_idx=mask_block_idx,
)
ctx.save_for_backward(q, k, v, out, lse)
ctx.softmax_scale = softmax_scale
Expand Down Expand Up @@ -706,6 +772,11 @@ def flash_attn_func(
learnable_sink: Optional[torch.Tensor] = None,
softcap: float = 0.0,
pack_gqa: Optional[bool] = None,
mask_mod: Optional[Callable] = None,
full_block_cnt: Optional[torch.Tensor] = None,
full_block_idx: Optional[torch.Tensor] = None,
mask_block_cnt: Optional[torch.Tensor] = None,
mask_block_idx: Optional[torch.Tensor] = None,
):
return FlashAttnFunc.apply(
q,
Expand All @@ -717,6 +788,11 @@ def flash_attn_func(
learnable_sink,
softcap,
pack_gqa,
mask_mod,
full_block_cnt,
full_block_idx,
mask_block_cnt,
mask_block_idx,
)


Expand Down Expand Up @@ -973,4 +1049,4 @@ def flash_attn_combine(
lse = None

_flash_attn_fwd_combine(out_partial, lse_partial, out, lse)
return out, lse
return out, lse
64 changes: 52 additions & 12 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2025, Tri Dao.

from typing import Optional
from typing import Optional, Callable
from dataclasses import dataclass

import cutlass
Expand All @@ -9,7 +9,6 @@

import flash_attn.cute.utils as utils


@cute.jit
def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
# Bit manipulation, compiles down to the R2P instruction
Expand Down Expand Up @@ -39,7 +38,6 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal
for r in cutlass.range_constexpr(cute.size(X.shape[0])):
X[r, c] = X[r, c] if in_bound else -Float32.inf


@dataclass(frozen=True)
class AttentionMask:
tile_m: cutlass.Constexpr[int]
Expand All @@ -55,12 +53,16 @@ class AttentionMask:
def apply_mask(
self,
acc_S: cute.Tensor,
m_block: Int32,
n_block: Int32,
batch_idx: cutlass.Int32,
head_idx: cutlass.Int32,
m_block: cutlass.Int32,
n_block: cutlass.Int32,
thr_mma: cute.TiledMma,
mask_seqlen: cutlass.Constexpr[bool],
mask_causal: cutlass.Constexpr[bool],
mask_local: cutlass.Constexpr[bool] = False,
mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
buffers: Optional[list[cute.Tensor]] = None,
) -> None:
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
Expand All @@ -76,17 +78,55 @@ def apply_mask(
COL = 1 if const_expr(not self.swap_AB) else 0
thr_col_offset = tScS_mn[0][COL]
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
if const_expr(not mask_causal and not mask_local):
if const_expr(not mask_causal and not mask_local and mask_mod is None):
if const_expr(mask_seqlen):
# The compiler now choses not to use R2P
r2p = const_expr(False and not self.swap_AB)
if const_expr(not r2p):
# traverse column index.
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
else:
mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)

elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # FlexAttention mask mod
nrow = const_expr(cute.size(tScS_mn.shape[0]))
ncol = const_expr(cute.size(tScS_mn.shape[1]))
thr_col_offset = tScS_mn[0, 0][1]

for r in cutlass.range_constexpr(nrow):
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m

for col in cutlass.range_constexpr(ncol):
col_idx_local = t0ScS_mn[0, col][1]
# Convert to absolute column index
global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n

cond = cutlass.Boolean(
mask_mod(
batch_idx,
head_idx,
tScS_mn[r, 0][0] + m_block * self.tile_m,
thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n,
self.seqlen_q,
self.seqlen_k,
buffers,
)
)
if const_expr(mask_seqlen):
out_of_bounds = (global_row_idx >= self.seqlen_q) or (
global_col_idx >= self.seqlen_k
)
if out_of_bounds:
acc_S_mn[r, col] = -cutlass.Float32.inf
else:
acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
else:
acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf


else: # Causal or local
if const_expr(not self.swap_AB):
# If PackGQA, we split the work of compute divmod among threads in the same row
Expand Down Expand Up @@ -303,22 +343,22 @@ def apply_mask_sm100_transposed(
tidx = cute.arch.thread_idx()[0] % 128

seqlenk_row_limit = self.seqlen_k - n_block * self.tile_n
if cutlass.const_expr(not mask_causal and not mask_local):
if cutlass.const_expr(mask_seqlen):
ncol = cutlass.const_expr(cute.size(tScS_t2r.shape))
if const_expr(not mask_causal and not mask_local):
if const_expr(mask_seqlen):
ncol = const_expr(cute.size(tScS_t2r.shape))
if tScS_t2r[0][0] >= seqlenk_row_limit:
for i in cutlass.range(ncol, unroll_full=True):
acc_S[i] = -cutlass.Float32.inf
else: # Causal or local
causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m
row_idx = tScS_t2r[0][0] + n_block * self.tile_n

if cutlass.const_expr(mask_causal):
if const_expr(mask_causal):
col_limit_left = row_idx + causal_row_offset
ncol = cutlass.const_expr(cute.size(tScS_t2r.shape))
ncol = const_expr(cute.size(tScS_t2r.shape))
# if tidx == 32 and wg_idx == 1:
# cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1])
if cutlass.const_expr(mask_seqlen):
if const_expr(mask_seqlen):
if tScS_t2r[0][0] >= seqlenk_row_limit:
col_limit_left = self.tile_m
for i in cutlass.range(ncol, unroll_full=True):
Expand Down
Loading