Skip to content
57 changes: 49 additions & 8 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,43 @@ def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> N
# cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)


@cute.jit
def mask_r2p_dual_bound(
    X: cute.Tensor,
    col_limit_left: Int32, # Inclusive lower bound
    col_limit_right: Int32, # Exclusive upper bound
) -> None:
    """
    Dual-bound masking using two bitmasks for SM100, following mask_r2p.
    Masks elements where: NOT (col_limit_left <= col < col_limit_right)

    X is modified in place: X[c] is overwritten with -Float32.inf when its
    flat index c falls outside [col_limit_left, col_limit_right), and is left
    unchanged otherwise. Nothing is returned.

    If col_limit_left >= col_limit_right the window is empty, the range mask
    is all zeros, and every element of X is masked.

    Uses bit manipulation to create a range mask:
        mask_right = (1 << right) - 1         -> bits (right-1)..0 are 1
        mask_left  = (1 << left) - 1          -> bits (left-1)..0 are 1
        mask_range = mask_right & ~mask_left  -> bits (right-1)..left are 1

    NOTE(review): the flat index c is used directly as the column index, so
    this assumes element c of X corresponds to column c -- confirm this holds
    for the layout of the tensor passed by the caller.
    """
    # Number of elements in X, wrapped in const_expr so the loops below can be
    # driven by range_constexpr.
    ncol = const_expr(cute.size(X.shape))

    # Walk X in chunks of 24 elements, one bitmask per chunk.
    # NOTE(review): the chunk width of 24 presumably mirrors mask_r2p -- confirm.
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Shift both limits into this chunk's local coordinates (chunk starts at
        # s * 24) and clamp below at 0 so they are valid shift amounts.
        right_s = max(col_limit_right - s * 24, 0)
        left_s = max(col_limit_left - s * 24, 0)

        # Clamp above at 24; otherwise the CuTe DSL complains about
        # "python int too large to convert into c long".
        # Only bits 0..23 are tested below, so the clamp does not change the result.
        right_s = min(right_s, 24)
        left_s = min(left_s, 24)

        # bits (right-1)..left are 1
        mask_right = (1 << right_s) - 1
        mask_left = (1 << left_s) - 1
        mask_range = mask_right & ~mask_left

        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        # The last chunk may hold fewer than 24 elements, hence the min().
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            # Bit i of mask_range is set iff local column i is inside the window.
            in_bound = cutlass.Boolean(mask_range & (1 << i))
            # Flat index of this element in X.
            c = s * 24 + i
            X[c] = X[c] if in_bound else -Float32.inf


@dataclass(frozen=True)
class AttentionMask:
tile_m: cutlass.Constexpr[int]
Expand Down Expand Up @@ -444,14 +481,18 @@ def apply_mask_sm100(
if const_expr(self.window_size_left is not None)
else 0
)
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
col_idx = tScS_t2r[i][1]
acc_S[i] = (
-Float32.inf
if col_idx >= col_limit_right or col_idx < col_limit_left
else acc_S[i]
)
if const_expr(not r2p):
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
col_idx = tScS_t2r[i][1]
acc_S[i] = (
-Float32.inf
if col_idx >= col_limit_right or col_idx < col_limit_left
else acc_S[i]
)
else:
# R2P dual-bound masking: keep bits in mask_right & ~mask_left (AND-NOT of two prefix masks, not XOR)
mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)

@cute.jit
def apply_mask_sm100_transposed(
Expand Down