5 changes: 0 additions & 5 deletions flash_attn/cute/flash_bwd_sm100.py
```diff
@@ -84,7 +84,6 @@ def __init__(
         self.use_2cta_instrs = bool(
             use_2cta_instrs
             and cluster_size == 2
-            and not is_local
             and score_mod is None
             and score_mod_bwd is None
             and mask_mod is None
@@ -928,10 +927,6 @@ class SharedStorage:
             "2-CTA mode does not support block sparsity. "
             "Please create kernel with use_2cta_instrs=False for block sparse attention."
         )
-        assert window_size_left is None and window_size_right is None, (
-            "2-CTA mode does not support window attention. "
-            "Please create kernel with use_2cta_instrs=False for window attention."
-        )
         # 2-CTA: 231424 and 1-CTA: 232448
         # print("SMEM: ", self.shared_storage.size_in_bytes())
         if const_expr(self.use_block_sparsity or aux_tensors is not None):
```
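For orientation, a minimal sketch of the relaxed 2-CTA eligibility check after this change; the standalone function form and its name are hypothetical (in the kernel this is an attribute computed in `__init__`):

```python
# Hypothetical standalone restatement of the kernel's gating logic.
def can_use_2cta_instrs(use_2cta_instrs, cluster_size,
                        score_mod, score_mod_bwd, mask_mod):
    return bool(
        use_2cta_instrs
        and cluster_size == 2
        # "and not is_local" was removed above: local (sliding-window)
        # attention no longer disqualifies 2-CTA mode.
        and score_mod is None
        and score_mod_bwd is None
        and mask_mod is None
    )
```

The matching runtime assert on `window_size_left`/`window_size_right` is dropped as well, so a 2-CTA kernel no longer rejects window attention at call time.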
3 changes: 1 addition & 2 deletions flash_attn/cute/interface.py
```diff
@@ -1060,8 +1060,7 @@ def _flash_attn_bwd(
     AtomLayoutMdQ = 1
     AtomLayoutNdKV = 1
     disable_2cta = (
-        local
-        or score_mod is not None
+        score_mod is not None
         or score_mod_bwd is not None
         or mask_mod is not None
     )
```
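With `local` dropped from `disable_2cta`, a sliding-window backward on SM100 is no longer forced onto the 1-CTA path. A hedged usage sketch, assuming `flash_attn_func` in `flash_attn.cute.interface` takes a `window_size` tuple with the same (left, right) semantics as the CUDA interface; check the actual signature before relying on this:

```python
import torch
from flash_attn.cute.interface import flash_attn_func  # assumed public entry point

q, k, v = (
    torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
ret = flash_attn_func(q, k, v, window_size=(256, 0))  # local attention
out = ret[0] if isinstance(ret, tuple) else ret       # some variants return (out, lse)
out.sum().backward()  # backward may now take the 2-CTA path on SM100
```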
6 changes: 1 addition & 5 deletions tests/cute/test_flash_attn.py
```diff
@@ -297,8 +297,6 @@ def test_flash_attn_output(
         # and False
         and not ((causal or local) and seqlen_k < seqlen_q)
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         if d > 192 and IS_SM90:
             pytest.xfail("hdim > 192 backward: SM90 not supported yet")
         if d != dv and mha_type != "mha" and IS_SM90:
@@ -405,7 +403,7 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
 # @pytest.mark.parametrize("d", [64, 96, 128])
 # @pytest.mark.parametrize("d", [128, 192])
-@pytest.mark.parametrize("d", [64, 128])
+@pytest.mark.parametrize("d", [64, 128, 192])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -737,8 +735,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not has_learnable_sink
         # and False
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         if d > 192 and IS_SM90:
             pytest.xfail("hdim > 192 backward: SM90 not supported yet")
         if d != dv and mha_type != "mha" and IS_SM90:
```
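Together, the hunks above stop xfail-ing the hdim-192 local case and add d=192 to the sweep. A condensed, hypothetical sketch of the newly exercised configuration (test name, shapes, and the cute entry point are illustrative assumptions, not copied from the suite):

```python
import pytest
import torch
from flash_attn.cute.interface import flash_attn_func  # assumed entry point

@pytest.mark.parametrize("d", [64, 128, 192])  # 192 newly included
@pytest.mark.parametrize("local", [False, True])
def test_bwd_local_hdim(d, local):
    q, k, v = (
        torch.randn(1, 256, 4, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        for _ in range(3)
    )
    window_size = (64, 0) if local else (-1, -1)  # semantics assumed from the CUDA API
    out = flash_attn_func(q, k, v, window_size=window_size)
    out = out[0] if isinstance(out, tuple) else out
    # The backward pass is what this PR enables for d == 192 with local=True;
    # previously this combination hit pytest.xfail.
    torch.autograd.grad(out, (q, k, v), torch.randn_like(out))
```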
4 changes: 0 additions & 4 deletions tests/cute/test_flash_attn_race_condition.py
```diff
@@ -255,8 +255,6 @@ def test_flash_attn_output(
         pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
     if IS_SM90 and local:
         pytest.xfail("SM90 backward: local attention not supported yet")
-    if d == 192 and local:
-        pytest.xfail("hdim 192 backward: local attention not supported yet")
     g = torch.randn_like(out)
     # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
     dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
@@ -658,8 +656,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not is_sm90
         # and False
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         g_unpad = torch.randn_like(out_unpad)
         # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
         # import flash_attn_3_cuda
```