53 changes: 28 additions & 25 deletions flash_attn/cute/interface.py
@@ -52,6 +52,13 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
+def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
+    assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
+    assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
+    assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
+    assert t.is_cuda, f"{name} must be on CUDA"
+
+
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
     torch.bfloat16: cutlass.BFloat16,
@@ -211,17 +218,7 @@ def _flash_attn_fwd(
             *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
         )
     else:
-        expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v)
-        assert out.shape == expected_out_shape, (
-            f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}"
-        )
-        assert out.dtype == out_torch_dtype, (
-            f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}"
-        )
-        assert out.device == device, (
-            f"out tensor device {out.device} does not match input device {device}"
-        )
-        assert out.is_cuda, "out tensor must be on CUDA device"
+        _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)
 
     if lse is None:
         lse = (
@@ -230,16 +227,7 @@
             else None
         )
     elif lse is not None:
-        assert lse.shape == lse_shape, (
-            f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}"
-        )
-        assert lse.dtype == torch.float32, (
-            f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32"
-        )
-        assert lse.device == device, (
-            f"lse tensor device {lse.device} does not match input device {device}"
-        )
-        assert lse.is_cuda, "lse tensor must be on CUDA device"
+        _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
     (
@@ -561,6 +549,9 @@ def _flash_attn_bwd(
     seqused_q: Optional[torch.Tensor] = None,
     seqused_k: Optional[torch.Tensor] = None,
     deterministic: bool = False,
+    dq: Optional[torch.Tensor] = None,
+    dk: Optional[torch.Tensor] = None,
+    dv: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     compute_capability = torch.cuda.get_device_capability()[0]
     assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x"
@@ -674,10 +665,22 @@
         assert deterministic is False, "bwd deterministic only supported for sm100 for now"
 
     device = q.device
-    # TODO: check if this is the right rounding
-    dq = torch.empty_like(q)
-    dk = torch.empty_like(k)
-    dv = torch.empty_like(v)
+    out_torch_dtype = q.dtype
+
+    if dq is None:
+        dq = torch.empty_like(q)
+    else:
+        _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)
+
+    if dk is None:
+        dk = torch.empty_like(k)
+    else:
+        _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)
+
+    if dv is None:
+        dv = torch.empty_like(v)
+    else:
+        _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)
 
     head_dim_rounded = (head_dim + 32 - 1) // 32 * 32
 
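For context on this interface change: the three new optional arguments let callers pass preallocated gradient buffers into `_flash_attn_bwd` instead of having the function allocate `dq`/`dk`/`dv` internally on every call. A minimal sketch of the call pattern, assuming an sm90/sm100 GPU (per the compute-capability assert above); the shapes and the reuse-across-iterations framing are illustrative, not stated in the PR:

```python
import torch

from flash_attn.cute.interface import _flash_attn_bwd, _flash_attn_fwd

q = torch.randn(2, 256, 4, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Allocate gradient buffers once, up front; on every call _validate_tensor
# checks each buffer against the shape, dtype, and device of q, k, v.
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)

for _ in range(3):  # illustrative reuse across repeated backward calls
    out, lse = _flash_attn_fwd(q, k, v, causal=True, return_lse=True)
    dout = torch.randn_like(out)
    # Gradients are written into the preallocated buffers, which are
    # also returned (dq_out is dq, etc.).
    dq_out, dk_out, dv_out = _flash_attn_bwd(
        q, k, v, out, dout, lse, causal=True, dq=dq, dk=dk, dv=dv
    )
```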
36 changes: 36 additions & 0 deletions tests/cute/test_flash_attn.py
@@ -1273,6 +1273,42 @@ def test_flash_attn_kvcache(
     ).abs().mean().item()
 
 
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [64, 128])
+@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)])
+def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype):
+    from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd
+
+    device = "cuda"
+    torch.random.manual_seed(42)
+    batch_size = 2
+    nheads = 4
+
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+    out, lse = _flash_attn_fwd(q, k, v, causal=causal, return_lse=True)
+    dout = torch.randn_like(out)
+
+    dq_ref, dk_ref, dv_ref = _flash_attn_bwd(q, k, v, out, dout, lse, causal=causal)
+
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+    dq_out, dk_out, dv_out = _flash_attn_bwd(
+        q, k, v, out, dout, lse, causal=causal, dq=dq, dk=dk, dv=dv
+    )
+
+    assert dq_out is dq
+    assert dk_out is dk
+    assert dv_out is dv
+    assert torch.allclose(dq, dq_ref, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(dk, dk_ref, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5)
+
+
 def _generate_block_kvcache(
     seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref
 ):
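One case the new test does not cover, sketched here as a hypothetical addition: a buffer with the wrong shape should now fail fast with a descriptive `AssertionError` from `_validate_tensor` rather than corrupting memory or silently reallocating. The `match` string follows the helper's `f"{name} shape ..."` format; this negative test is not part of the PR:

```python
import pytest
import torch

from flash_attn.cute.interface import _flash_attn_bwd, _flash_attn_fwd

q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out, lse = _flash_attn_fwd(q, k, v, causal=False, return_lse=True)
dout = torch.randn_like(out)

# Wrong head dimension (72 instead of 64): rejected by the shape assert
# in _validate_tensor, since dq must match q.shape exactly.
dq_bad = torch.empty(2, 128, 4, 72, device="cuda", dtype=torch.bfloat16)
with pytest.raises(AssertionError, match="dq shape"):
    _flash_attn_bwd(q, k, v, out, dout, lse, causal=False, dq=dq_bad)
```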