16 changes: 10 additions & 6 deletions flash_attn/cute/block_info.py
@@ -58,12 +58,16 @@ def get_n_block_min_max(
def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
m_block_min = 0
if const_expr(self.is_causal):
m_block_min = max(
m_block_min,
(n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k)
// self.tile_m,
)
if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
n_idx_min = n_block * self.tile_n
m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right
m_block_min = max(m_block_min, m_idx_right // self.tile_m)
if const_expr(self.is_local and self.window_size_left is not None):
n_idx_max = (n_block + 1) * self.tile_n
m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k
m_idx_left = m_idx + self.window_size_left
m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
return m_block_min, m_block_max

@cute.jit
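As a sanity check on the block-range arithmetic above, here is a minimal standalone sketch (plain Python with hypothetical sizes, not the CuTe DSL code itself) of which query blocks a given key block can touch under a causal or local window:

```python
import math

def m_block_range(seqlen_q, seqlen_k, tile_m, tile_n, n_block,
                  causal=False, window_size_left=None, window_size_right=None):
    # Mirrors get_m_block_min_max: the range of query blocks that can interact
    # with key block `n_block`, using the bottom-right-aligned diagonal.
    m_block_max = math.ceil(seqlen_q / tile_m)
    m_block_min = 0
    if causal or window_size_right is not None:
        m_idx = n_block * tile_n + seqlen_q - seqlen_k
        m_idx_right = m_idx if causal else m_idx - window_size_right
        m_block_min = max(m_block_min, m_idx_right // tile_m)
    if window_size_left is not None:
        m_idx = (n_block + 1) * tile_n + seqlen_q - seqlen_k
        m_block_max = min(m_block_max, math.ceil((m_idx + window_size_left) / tile_m))
    return m_block_min, m_block_max

# seqlen_q = seqlen_k = 1024, tile_m = 64, tile_n = 128, key block 3, window (128, 0):
# only query blocks 6..9 overlap the window, so the backward pass can skip the rest.
print(m_block_range(1024, 1024, 64, 128, 3, window_size_left=128, window_size_right=0))  # (6, 10)
```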
558 changes: 316 additions & 242 deletions flash_attn/cute/flash_bwd_sm100.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions flash_attn/cute/interface.py
@@ -295,6 +295,7 @@ def _flash_attn_fwd(
if window_size_left is not None or window_size_right is not None:
if window_size_left is None and window_size_right == 0:
causal, local = True, False
window_size_right = None
else:
causal, local = False, True
else:
@@ -540,6 +541,8 @@ def _flash_attn_bwd(
softmax_scale: Optional[float] = None,
causal: bool = False,
softcap: float = 0.0,
window_size_left: Optional[int] = None,
window_size_right: Optional[int] = None,
m_block_size: int = 64,
n_block_size: int = 128,
num_threads: int = 256,
@@ -575,6 +578,7 @@
AtomLayoutNdKV = 2
AtomLayoutMdQ = 1
cluster_size = 1
assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
else:
m_block_size = 128
n_block_size = 128
@@ -608,6 +612,16 @@
num_head_kv = k.shape[-2]
head_dim_v = v.shape[-1]

if causal:
window_size_right = 0
local = window_size_left is not None or window_size_right is not None
if local:
if window_size_left is None and window_size_right == 0:
causal, local = True, False
window_size_right = None
else:
causal, local = False, True

if cu_seqlens_k is None:
assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
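For reference, the window_size → (causal, local) normalization added above can be sketched as a small pure-Python helper (hypothetical name, same rules as the block in `_flash_attn_bwd`): `causal=True` is folded into a right window of 0, and a `(None, 0)` window collapses back to the plain causal path.

```python
from typing import Optional, Tuple

def normalize_window(causal: bool,
                     window_size_left: Optional[int],
                     window_size_right: Optional[int]
                     ) -> Tuple[bool, bool, Optional[int], Optional[int]]:
    # causal is encoded as a right window of 0
    if causal:
        window_size_right = 0
    local = window_size_left is not None or window_size_right is not None
    if local:
        if window_size_left is None and window_size_right == 0:
            # (None, 0) is exactly causal: take the cheaper causal path, no window args.
            causal, local = True, False
            window_size_right = None
        else:
            causal, local = False, True
    return causal, local, window_size_left, window_size_right

assert normalize_window(True, None, None) == (True, False, None, None)
assert normalize_window(False, None, 0) == (True, False, None, None)
assert normalize_window(False, 128, None) == (False, True, 128, None)
```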
@@ -840,6 +854,8 @@ def _flash_attn_bwd(
head_dim_v,
qhead_per_kvhead,
causal,
window_size_left is not None,
window_size_right is not None,
softcap != 0.0,
m_block_size,
n_block_size,
@@ -896,6 +912,7 @@
head_dim,
head_dim_v,
is_causal=causal,
is_local=local,
qhead_per_kvhead=qhead_per_kvhead,
# tile_m=m_block_size,
# tile_n=n_block_size,
@@ -921,6 +938,8 @@
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
window_size_left=window_size_left,
window_size_right=window_size_right,
mdQ_semaphore=dQ_semaphore_tensor,
mdK_semaphore=dK_semaphore_tensor,
mdV_semaphore=dV_semaphore_tensor,
@@ -941,6 +960,8 @@
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
window_size_left=window_size_left,
window_size_right=window_size_right,
mdQ_semaphore=dQ_semaphore_tensor,
mdK_semaphore=dK_semaphore_tensor,
mdV_semaphore=dV_semaphore_tensor,
@@ -1103,6 +1124,8 @@ def backward(ctx, dout, *args):
ctx.softmax_scale,
ctx.causal,
ctx.softcap,
window_size_left=ctx.window_size[0],
window_size_right=ctx.window_size[1],
deterministic=ctx.deterministic,
)
return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
39 changes: 26 additions & 13 deletions flash_attn/cute/mask.py
@@ -239,10 +239,10 @@ def apply_mask(
)
if const_expr(self.window_size_right is not None):
col_limit_right = row_idx + local_row_offset_right
if const_expr(mask_seqlen):
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
else:
col_limit_right = self.tile_n
if const_expr(mask_seqlen):
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
col_limit_left = (
row_idx + local_row_offset_left
if const_expr(self.window_size_left is not None)
@@ -411,10 +411,10 @@ def apply_mask_sm100(
)
if const_expr(self.window_size_right is not None):
col_limit_right = row_idx + local_row_offset_right
if const_expr(mask_seqlen):
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
else:
col_limit_right = self.tile_n
if const_expr(mask_seqlen):
col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
col_limit_left = (
row_idx + local_row_offset_left
if const_expr(self.window_size_left is not None)
@@ -447,28 +447,27 @@ def apply_mask_sm100_transposed(
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
ROW = 0 if const_expr(not self.swap_AB) else 1
COL = 1 if const_expr(not self.swap_AB) else 0
assert t0ScS_t2r[0][COL] == 0, "col0 == 0"
thr_col_offset = tScS_t2r[0][COL]
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
if const_expr(not mask_causal and not mask_local):
if const_expr(mask_seqlen):
if t0ScS_t2r[0][COL] >= seqlenk_col_limit:
if seqlenk_col_limit <= 0:
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
acc_S[i] = -cutlass.Float32.inf
else: # Causal or local
thr_row_offset = tScS_t2r[0][ROW]
causal_row_offset = (
seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
)
seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
causal_offset = seqlenq_row_limit - seqlenk_col_limit
if const_expr(mask_causal):
col0 = t0ScS_t2r[0][COL]
row_limit_top = col0 - causal_row_offset
# tidx = cute.arch.thread_idx()[0] % 256
# if tidx < 32:
# cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0)
# cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
row_limit_top = causal_offset
if const_expr(mask_seqlen):
# If col is beyond the column limit, we want to mask out the entire
# column, by setting row limit to be self.tile_m.
if t0ScS_t2r[0][COL] >= seqlenk_col_limit:
if seqlenk_col_limit <= 0:
row_limit_top = self.tile_m
r2p = True
if const_expr(not r2p):
@@ -480,4 +479,18 @@
num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
mask_r2p_transposed(acc_S, row_limit_top, num_rep)
else:
assert False, "Local masking isn't supported yet"
if const_expr(self.window_size_right is not None):
row_limit_top = causal_offset - self.window_size_right
else:
row_limit_top = 0
if const_expr(self.window_size_left is not None):
row_limit_bot = causal_offset + self.window_size_left
if const_expr(mask_seqlen):
if seqlenk_col_limit <= 0:
row_limit_top = self.tile_m
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
row_idx = t0ScS_t2r[i][ROW]
local_mask = row_idx < row_limit_top
if const_expr(self.window_size_left is not None):
local_mask |= row_idx > row_limit_bot
acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
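The row limits above encode the usual local-window semantics in (query, key) index space. A minimal reference sketch (assuming the bottom-right-aligned diagonal convention used elsewhere in this repo, with `None` meaning unbounded on that side) is:

```python
import torch

def local_window_mask(seqlen_q, seqlen_k, window_size_left, window_size_right):
    # True = keep the score, False = set it to -inf.
    q = torch.arange(seqlen_q)[:, None]
    k = torch.arange(seqlen_k)[None, :]
    diag = q + seqlen_k - seqlen_q          # key index on the diagonal for query q
    keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if window_size_right is not None:
        keep &= k <= diag + window_size_right
    if window_size_left is not None:
        keep &= k >= diag - window_size_left
    return keep

# window (None, 0) is causal; a finite left window additionally cuts off old keys.
print(local_window_mask(4, 6, None, 0).int())
print(local_window_mask(4, 6, 1, 0).int())
```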
6 changes: 5 additions & 1 deletion flash_attn/cute/testing.py
@@ -260,8 +260,12 @@ def construct_local_mask(
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
if window_size[1] is None:
local_mask_left = col_idx > sk
else:
local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk)
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
local_mask_left,
torch.logical_and(
col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
),
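A quick numeric check of the `None` right-window branch added above (hypothetical sizes; key padding and sink tokens omitted): with `window_size = (w_left, None)` the only right-side constraint left is `col_idx > sk`, i.e. out-of-range keys, so the window is bounded on the left only.

```python
import torch

seqlen_q, seqlen_k, w_left = 4, 6, 1
row_idx = torch.arange(seqlen_q)[:, None]
col_idx = torch.arange(seqlen_k)[None, :]
sq, sk = seqlen_q, seqlen_k
# window_size[1] is None: the right-hand term masks nothing in range.
beyond_right = col_idx > sk
beyond_left = col_idx < row_idx + sk - sq - w_left
masked = torch.logical_or(beyond_right, beyond_left)
print((~masked).int())  # each query sees keys from one position left of its diagonal to the end
```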
50 changes: 38 additions & 12 deletions tests/cute/test_flash_attn.py
@@ -29,7 +29,8 @@


DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"

TEST_BWD_ONLY = False
VERBOSE = True

# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -43,8 +44,8 @@
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -92,16 +93,17 @@ def test_flash_attn_output(
seqlen_k,
d,
causal,
local,
local_enum,
softcap,
deterministic,
has_qv,
has_learnable_sink,
mha_type,
dtype,
):
# if (causal or local) and seqlen_k < seqlen_q:
# pytest.skip("Causal attention requires seqlen_k >= seqlen_q")
local = local_enum > 0
if local and causal:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
@@ -115,7 +117,7 @@
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
if dtype == torch.float8_e4m3fn:
if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
dv_vals = [d]
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
attention_chunk_vals = [0]
@@ -157,6 +159,12 @@
window_size = (
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
)
if local_enum == 2:
window_size = (None, -window_size[1])
elif local_enum == 3:
window_size = (-window_size[0], None)
if local:
print("window size = ", window_size)
# window_size = (-1, -1) if not local else (16, 0)
if has_learnable_sink:
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
@@ -228,7 +236,7 @@
# pack_gqa_vals = [False, True, None]
# SplitKV is not supported for hdim >= 192
pack_gqa_vals = [False]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
out, lse = flash_attn_func(
q,
@@ -241,8 +249,9 @@
# attention_chunk=attention_chunk,
softcap=softcap,
learnable_sink=learnable_sink,
# pack_gqa=pack_gqa,
pack_gqa=pack_gqa,
num_splits=num_splits,
deterministic=deterministic,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
@@ -262,12 +271,9 @@
and not dv > 256
and not attention_chunk != 0
and softcap == 0.0
and not local
and dv == d
and learnable_sink is None
# and mha_type == "mha"
# and False
and not ((causal or local) and seqlen_k < seqlen_q)
):
g = torch.randn_like(out)
# do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
@@ -301,6 +307,26 @@ def test_flash_attn_output(
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

if VERBOSE:
diff_dq = (dq - dq_ref).abs()
max_idx = diff_dq.argmax()
coords = torch.unravel_index(max_idx, diff_dq.shape)
print(f"dQ max diff: {diff_dq.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}")

diff_dk = (dk - dk_ref).abs()
max_idx = diff_dk.argmax()
coords = torch.unravel_index(max_idx, diff_dk.shape)
print(f"dK max diff: {diff_dk.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}")

diff_dv = (dv - dv_ref).abs()
max_idx = diff_dv.argmax()
coords = torch.unravel_index(max_idx, diff_dv.shape)
print(f"dV max diff: {diff_dv.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}")

# breakpoint()
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4