Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,18 @@ def __init__(
self.score_mod = score_mod
self.mask_mod = mask_mod
self.qk_acc_dtype = Float32
self.vec_size: cutlass.Constexpr = getattr(
self.score_vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
)
if self.vec_size > 2:
if self.score_vec_size > 2:
raise ValueError(
f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 "
f"score_mod vec_size {self.score_vec_size} not supported on Sm80/90/120 "
"due to accumulator thread ownership pattern."
)
self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1)
if self.mask_vec_size > 1:
raise ValueError(
f"mask_mod vec_size {self.mask_vec_size} not supported on Sm80/90/120 "
"due to accumulator thread ownership pattern."
)
self.arch = BaseDSL._get_dsl().get_arch_enum()
Expand Down Expand Up @@ -1211,7 +1217,7 @@ def apply_score_mod(
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.score_vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
Expand Down
6 changes: 4 additions & 2 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ def __init__(
)
self.score_mod = score_mod
self.mask_mod = mask_mod
self.vec_size: cutlass.Constexpr = getattr(
self.score_vec_size: cutlass.Constexpr = getattr(
score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
)
self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1)
# Does S1 need to wait for S0 to finish
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
Expand Down Expand Up @@ -1943,6 +1944,7 @@ def softmax_loop(
batch_idx=batch_idx,
head_idx=head_idx,
aux_tensors=aux_tensors,
vec_size=self.mask_vec_size,
)

# Recompute fastdiv_mods if necessary
Expand Down Expand Up @@ -3107,7 +3109,7 @@ def apply_score_mod(
batch_idx,
head_idx,
softmax.softmax_scale,
self.vec_size,
self.score_vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
Expand Down
6 changes: 4 additions & 2 deletions flash_attn/cute/flash_fwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def kernel(
blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None
),
mCuBlockIdxOffsets=(
blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None
blocksparse_tensors.cu_block_idx_offsets
if blocksparse_tensors is not None
else None
),
# Don't need to pass in tile_mn because we won't access offset_padded
)
Expand Down Expand Up @@ -1507,7 +1509,7 @@ def apply_score_mod(
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.score_vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
Expand Down
287 changes: 242 additions & 45 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep)


@cute.jit
def apply_packed_mask_chunk(
X: cute.Tensor,
chunk_idx: cutlass.Constexpr[int],
mask: Uint32,
) -> None:
"""Apply one 32-bit keep mask to one 32-column chunk.

The one-iteration chunk loop keeps the same lowering pattern as mask_r2p_lambda.
"""
ncol = const_expr(cute.size(X.shape))
col_base = chunk_idx * MASK_R2P_CHUNK_SIZE
for s in cutlass.range_constexpr(1):
for i in cutlass.range_constexpr(
min(MASK_R2P_CHUNK_SIZE, ncol - col_base - s * MASK_R2P_CHUNK_SIZE)
):
in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
c = col_base + s * MASK_R2P_CHUNK_SIZE + i
X[c] = X[c] if in_bound else -Float32.inf


@dataclass(frozen=True)
class AttentionMask:
tile_m: cutlass.Constexpr[int]
Expand Down Expand Up @@ -369,6 +390,192 @@ def mask_gen_fn(s: int) -> Uint32:
else acc_S_mn[r, c]
)

@cute.jit
def apply_mask_mod_sm100_scalar(
self,
acc_S: cute.Tensor,
tScS_t2r: cute.Tensor,
m_block: Int32,
n_block: Int32,
mask_seqlen: cutlass.Constexpr[bool],
mask_mod: cutlass.Constexpr[Callable],
batch_idx: Int32,
head_idx: Int32,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
head_divmod=None,
check_q_boundary: bool = False,
) -> None:
"""Apply a scalar FlexAttention mask_mod to an SM100 accumulator fragment.

Each accumulator lane calls mask_mod once with logical (batch, head, q, kv)
indices. Pack-GQA rows are converted back to logical q/head indices before
the call. When aux tensors are present, indices are wrapped with fastdiv so
mask_mod never reads outside the per-example auxiliary storage.
"""
has_fastdiv = const_expr(
fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None
)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
ncol = const_expr(cute.size(tScS_t2r.shape))

for i in cutlass.range_constexpr(ncol):
row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
global_row = row_coord + m_block * self.tile_m
global_col = col_coord + n_block * self.tile_n

if const_expr(self.qhead_per_kvhead_packgqa != 1):
assert head_divmod is not None
mask_row, head_offset = divmod(global_row, head_divmod)
head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
else:
head_idx_for_mod = head_idx
mask_row = global_row

mask_row_for_mod = mask_row
if const_expr(has_fastdiv and aux_tensors is not None):
if check_q_boundary:
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
global_col_for_mod = global_col
if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
_, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
head_idx_ssa,
mask_row_ssa,
kv_idx_ssa,
self.seqlen_info,
aux_tensors,
)
cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
acc_S[i] = acc_S[i] if cond else -Float32.inf
if const_expr(mask_seqlen):
acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
if check_q_boundary:
acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

@cute.jit
def apply_mask_mod_sm100_vector(
self,
acc_S: cute.Tensor,
tScS_t2r: cute.Tensor,
m_block: Int32,
n_block: Int32,
mask_seqlen: cutlass.Constexpr[bool],
mask_mod: cutlass.Constexpr[Callable],
batch_idx: Int32,
head_idx: Int32,
vec_size: cutlass.Constexpr[int],
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
head_divmod=None,
check_q_boundary: bool = False,
) -> None:
"""Apply a vectorized FlexAttention mask_mod to an SM100 fragment.

mask_mod receives vec_size adjacent KV indices for one logical q row and
returns bit-packed Uint32 keep masks. Low bits correspond to lower KV
indices. The packed masks are combined with sequence-boundary checks, then
applied in 32-column chunks so the final masking lowers to R2P.
"""
has_fastdiv = const_expr(
fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None
)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
ncol = const_expr(cute.size(tScS_t2r.shape))
mask_vals_per_apply = const_expr(max(1, vec_size // 32))
calls_per_apply = const_expr(max(1, 32 // vec_size))
n_calls = const_expr(cute.ceil_div(ncol, vec_size))
mask_vals = cute.make_rmem_tensor(mask_vals_per_apply, dtype=cutlass.Uint32)

# Accumulate enough vector mask_mod calls to produce 32-bit chunks that
# apply_packed_mask_chunk can lower to R2P.
for s in cutlass.range_constexpr(n_calls):
if const_expr(s % calls_per_apply == 0):
for c in cutlass.range_constexpr(mask_vals_per_apply):
mask_vals[c] = cutlass.Uint32(0)
i = s * vec_size
row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
global_row = row_coord + m_block * self.tile_m
global_col = col_coord + n_block * self.tile_n
if const_expr(self.qhead_per_kvhead_packgqa != 1):
assert head_divmod is not None
mask_row, head_offset = divmod(global_row, head_divmod)
head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
else:
head_idx_for_mod = head_idx
mask_row = global_row
mask_row_for_mod = mask_row
if const_expr(has_fastdiv and aux_tensors is not None):
if check_q_boundary:
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])

head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32).broadcast_to(
(vec_size,)
)
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32).broadcast_to(
(vec_size,)
)
batch_idx_ssa_call = batch_idx_ssa.broadcast_to((vec_size,))
kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

# Build the per-lane KV indices for this vectorized mask_mod call.
for j in cutlass.range_constexpr(min(vec_size, ncol - i)):
col_j_coord = tScS_t2r[i + j][1] if not self.swap_AB else tScS_t2r[i + j][0]
col_j_global = col_j_coord + n_block * self.tile_n
col_j_for_mod = col_j_global
if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
_, col_j_for_mod = divmod(col_j_global, fastdiv_mods[1])
kv_idx_vec[j] = col_j_for_mod
kv_idx_ssa = kv_idx_vec.load()

# mask_value is already bit-packed by the vectorized mask_mod.
mask_value = mask_mod(
batch_idx_ssa_call,
head_idx_ssa,
mask_row_ssa,
kv_idx_ssa,
self.seqlen_info,
aux_tensors,
)

# For vec_size < 32, multiple mask_mod calls fill one R2P chunk.
bit_offset = const_expr((s % calls_per_apply) * vec_size)
seqlen_thresh_call = (
self.seqlen_k - global_col if const_expr(mask_seqlen) else cutlass.Int32(0)
)
q_in_bounds = mask_row < self.seqlen_q if check_q_boundary else cutlass.Boolean(True)
for c in cutlass.range_constexpr(mask_vals_per_apply):
mask_val = mask_value[c]
if const_expr(vec_size < 32):
lane_keep = utils.shr_u32(
cutlass.Uint32(0xFFFFFFFF),
cutlass.Uint32(32 - vec_size),
)
mask_val = mask_val & lane_keep
if const_expr(mask_seqlen):
mask_val = mask_val & r2p_bitmask_below(seqlen_thresh_call, c)
if check_q_boundary:
mask_val = mask_val if q_in_bounds else cutlass.Uint32(0)
mask_vals[c] = mask_vals[c] | (mask_val << bit_offset)

# Apply only when the 32-bit chunk is complete, or at the tile tail.
is_last_in_apply = const_expr(s % calls_per_apply == calls_per_apply - 1)
is_last_overall = const_expr(s == n_calls - 1)
if const_expr(is_last_in_apply or is_last_overall):
apply_idx = s // calls_per_apply
for c in cutlass.range_constexpr(mask_vals_per_apply):
chunk_idx = apply_idx * mask_vals_per_apply + c
# Skip packed chunks that start past the accumulator fragment.
if const_expr(chunk_idx * 32 < ncol):
apply_packed_mask_chunk(acc_S, chunk_idx, mask_vals[c])

@cute.jit
def apply_mask_sm100(
self,
Expand All @@ -386,6 +593,7 @@ def apply_mask_sm100(
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
head_divmod=None,
vec_size: cutlass.Constexpr[int] = 1,
check_q_boundary: bool = False,
r2p: bool = True,
rBitmask: Optional[cute.Tensor] = None,
Expand Down Expand Up @@ -429,54 +637,43 @@ def apply_mask_sm100(
)

elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
# Block sparse case w/ mask_mod
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None
and fastdiv_mods[1] is not None
# FlexAttention mask_mod vectorization is gated on `mask_mod.__vec_size__`.
# vec_size == 1 returns a scalar Boolean. vec_size > 1 returns packed
# Uint32 mask fragments: one word per 32 evaluated columns.
assert vec_size % 32 == 0 or 32 % vec_size == 0, (
"vec_size must divide 32 or be a multiple of 32"
)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)

ncol = const_expr(cute.size(tScS_t2r.shape))
for i in cutlass.range_constexpr(ncol):
row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
global_row = row_coord + m_block * self.tile_m
global_col = col_coord + n_block * self.tile_n

if const_expr(self.qhead_per_kvhead_packgqa != 1):
assert head_divmod is not None
mask_row, head_offset = divmod(global_row, head_divmod)
head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
else:
head_idx_for_mod = head_idx
mask_row = global_row

mask_row_for_mod = mask_row
if const_expr(has_fastdiv and aux_tensors is not None):
if check_q_boundary:
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
global_col_for_mod = global_col
if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
_, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
head_idx_ssa,
mask_row_ssa,
kv_idx_ssa,
self.seqlen_info,
if const_expr(vec_size == 1):
self.apply_mask_mod_sm100_scalar(
acc_S,
tScS_t2r,
m_block,
n_block,
mask_seqlen,
mask_mod,
batch_idx,
head_idx,
aux_tensors,
fastdiv_mods,
head_divmod,
check_q_boundary,
)
else:
self.apply_mask_mod_sm100_vector(
acc_S,
tScS_t2r,
m_block,
n_block,
mask_seqlen,
mask_mod,
batch_idx,
head_idx,
vec_size,
aux_tensors,
fastdiv_mods,
head_divmod,
check_q_boundary,
)
cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
acc_S[i] = acc_S[i] if cond else -Float32.inf
if const_expr(mask_seqlen):
acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
if check_q_boundary:
acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

else: # Causal or local
causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q
Expand Down
Loading