@cute.jit
def mask_r2p_dual_bound(
    X: cute.Tensor,
    col_limit_left: Int32,  # Inclusive lower bound
    col_limit_right: Int32,  # Exclusive upper bound
) -> None:
    """
    Mask X in place so only columns in [col_limit_left, col_limit_right) survive.

    Dual-bound counterpart of mask_r2p for SM100: element ``c`` of ``X`` is kept
    when ``col_limit_left <= c < col_limit_right`` and overwritten with
    ``-Float32.inf`` otherwise.

    The columns are processed in chunks of 24. For every chunk both limits are
    rebased to the chunk start and clamped to [0, 24], then combined into a
    single range bitmask:
        mask_right = (1 << right) - 1       -> bits (right-1)..0 are 1
        mask_left  = (1 << left) - 1        -> bits (left-1)..0 are 1
        mask_range = mask_right & ~mask_left -> bits (right-1)..left are 1
    Bit ``i`` of that mask decides whether column ``chunk_start + i`` is kept.
    When left >= right the mask is 0, so the whole chunk is masked out.
    """
    chunk_bits = 24
    num_cols = const_expr(cute.size(X.shape))

    for chunk in cutlass.range_constexpr(cute.ceil_div(num_cols, chunk_bits)):
        chunk_start = chunk * chunk_bits

        # Chunk-local limits. The lower clamp keeps the shift amount non-negative;
        # the upper clamp is needed because otherwise cute dsl complains about
        # python int too large to convert into c long.
        hi = min(max(col_limit_right - chunk_start, 0), chunk_bits)
        lo = min(max(col_limit_left - chunk_start, 0), chunk_bits)

        # bits (hi-1)..lo are 1
        keep_bits = ((1 << hi) - 1) & ~((1 << lo) - 1)

        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for bit in cutlass.range_constexpr(min(chunk_bits, num_cols - chunk_start)):
            keep = cutlass.Boolean(keep_bits & (1 << bit))
            col = chunk_start + bit
            X[col] = X[col] if keep else -Float32.inf
cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) - for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): - col_idx = tScS_t2r[i][1] - acc_S[i] = ( - -Float32.inf - if col_idx >= col_limit_right or col_idx < col_limit_left - else acc_S[i] - ) + if const_expr(not r2p): + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + col_idx = tScS_t2r[i][1] + acc_S[i] = ( + -Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) + else: + # XOR-based R2P dual bound masking + mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right) @cute.jit def apply_mask_sm100_transposed(