diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 78a8e7c2cc4..0527e241d35 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -1246,7 +1246,7 @@ def test_flash3_bw_compatibility() -> None: "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " - "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + "-> (Tensor, Tensor, Tensor, Tensor, Tensor)" )) assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "