diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py
index 3a2e821ed00..e06cd811fc6 100644
--- a/flash_attn/cute/flash_bwd_sm100.py
+++ b/flash_attn/cute/flash_bwd_sm100.py
@@ -84,7 +84,6 @@ def __init__(
         self.use_2cta_instrs = bool(
             use_2cta_instrs
             and cluster_size == 2
-            and not is_local
             and score_mod is None
             and score_mod_bwd is None
             and mask_mod is None
@@ -928,10 +927,6 @@ class SharedStorage:
                 "2-CTA mode does not support block sparsity. "
                 "Please create kernel with use_2cta_instrs=False for block sparse attention."
             )
-            assert window_size_left is None and window_size_right is None, (
-                "2-CTA mode does not support window attention. "
-                "Please create kernel with use_2cta_instrs=False for window attention."
-            )
         # 2-CTA: 231424 and 1-CTA: 232448
         # print("SMEM: ", self.shared_storage.size_in_bytes())
         if const_expr(self.use_block_sparsity or aux_tensors is not None):
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 561332659a8..5aafacf03b0 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -1060,8 +1060,7 @@ def _flash_attn_bwd(
         AtomLayoutMdQ = 1
         AtomLayoutNdKV = 1
         disable_2cta = (
-            local
-            or score_mod is not None
+            score_mod is not None
             or score_mod_bwd is not None
             or mask_mod is not None
         )
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 2eacce6a150..b6264702e40 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -297,8 +297,6 @@ def test_flash_attn_output(
         # and False
         and not ((causal or local) and seqlen_k < seqlen_q)
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         if d > 192 and IS_SM90:
             pytest.xfail("hdim > 192 backward: SM90 not supported yet")
         if d != dv and mha_type != "mha" and IS_SM90:
@@ -405,7 +403,7 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
 # @pytest.mark.parametrize("d", [64, 96, 128])
 # @pytest.mark.parametrize("d", [128, 192])
-@pytest.mark.parametrize("d", [64, 128])
+@pytest.mark.parametrize("d", [64, 128, 192])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -737,8 +735,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not has_learnable_sink
         # and False
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         if d > 192 and IS_SM90:
             pytest.xfail("hdim > 192 backward: SM90 not supported yet")
         if d != dv and mha_type != "mha" and IS_SM90:
diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py
index 4e36733dc7f..a9b8799f4c1 100644
--- a/tests/cute/test_flash_attn_race_condition.py
+++ b/tests/cute/test_flash_attn_race_condition.py
@@ -255,8 +255,6 @@ def test_flash_attn_output(
         pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)")
     if IS_SM90 and local:
         pytest.xfail("SM90 backward: local attention not supported yet")
-    if d == 192 and local:
-        pytest.xfail("hdim 192 backward: local attention not supported yet")
     g = torch.randn_like(out)
     # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
     dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
@@ -658,8 +656,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not is_sm90
         # and False
     ):
-        if d == 192 and local:
-            pytest.xfail("hdim 192 backward: local attention not supported yet")
         g_unpad = torch.randn_like(out_unpad)
         # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
         # import flash_attn_3_cuda
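
A minimal sketch of how the newly enabled path could be exercised once this patch is applied: local (sliding-window) attention backward, including head dim 192, which previously hit the removed is_local guard in the 2-CTA SM100 kernel and the xfails deleted from the tests above. The flash_attn_func entry point, its window_size argument, and the shapes below are assumptions modeled on the public flash-attn API and the parameters exercised in tests/cute/test_flash_attn.py, not taken verbatim from this repo.

# Sketch only: assumes flash_attn.cute.interface exposes flash_attn_func with a
# window_size=(left, right) argument; requires an SM100 GPU for the 2-CTA path.
import torch
from flash_attn.cute.interface import flash_attn_func  # assumed entry point

batch, seqlen, nheads, headdim = 2, 1024, 8, 192  # headdim 192 now covered by the tests
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim,
                dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
]

# Local (sliding-window) attention: with the is_local guard removed, the backward
# pass is no longer forced onto the use_2cta_instrs=False fallback.
out = flash_attn_func(q, k, v, window_size=(128, 0))  # some versions also return the LSE; adjust unpacking if needed
g = torch.randn_like(out)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)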