32 changes: 22 additions & 10 deletions flash_attn/cute/interface.py
@@ -68,6 +68,24 @@ def _get_device_arch():
major, minor = torch.cuda.get_device_capability()
return major * 10 + minor


def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:
"""Validate head dimension constraints based on compute capability."""
is_deepseek_shape = head_dim == 192 and head_dim_v == 128
is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128

if compute_capability == 9:
assert is_standard_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}."
)
elif compute_capability in [10, 11]:
assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek."
)


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

@@ -217,11 +235,11 @@ def _flash_attn_fwd(
learnable_sink,
)
), "inputs must be on CUDA device"
arch = _get_device_arch() if _arch is None else _arch
assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
_validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
if softcap == 0.0:
@@ -253,10 +271,6 @@ def _flash_attn_fwd(
_validate_tensor(lse, "lse", lse_shape, torch.float32, device)

dtype = torch2cute_dtype_map[q.dtype]
arch = _get_device_arch() if _arch is None else _arch

assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

use_block_sparsity = block_sparse_tensors is not None

if mask_mod is None:
@@ -748,10 +762,8 @@ def _flash_attn_bwd(
t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
), "inputs must be on CUDA device"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
_validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
qhead_per_kvhead = num_head // num_head_kv
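For reference, a minimal sketch of how the new head-dimension check composes with the alignment rule (an illustrative example, not part of this diff; it assumes `_validate_head_dims` can be imported from `flash_attn.cute.interface`, and uses the usual element sizes of 2 bytes for bf16/fp16 and 1 byte for fp8):

```python
# Illustrative sketch: alignment is 16 bytes divided by the element size,
# so bf16/fp16 tensors need head dims divisible by 8, fp8 by 16.
from flash_attn.cute.interface import _validate_head_dims

bf16_alignment = 16 // 2  # element_size() == 2 -> alignment 8
fp8_alignment = 16 // 1   # element_size() == 1 -> alignment 16

# Standard shape within [8, 128]: accepted on SM90 and SM100/SM110.
_validate_head_dims(head_dim=128, head_dim_v=128, compute_capability=9, alignment=bf16_alignment)

# DeepSeek shape (192, 128): accepted on SM100/SM110 only.
_validate_head_dims(head_dim=192, head_dim_v=128, compute_capability=10, alignment=bf16_alignment)

# The same shape raises on SM90.
try:
    _validate_head_dims(head_dim=192, head_dim_v=128, compute_capability=9, alignment=bf16_alignment)
except AssertionError as e:
    print(e)  # "(head_dim, head_dim_v)=(192, 128) is not supported on SM90. ..."
```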
15 changes: 15 additions & 0 deletions tests/cute/test_flash_attn.py
@@ -4,6 +4,7 @@
import itertools
import os
import random
import re

import pytest
import torch
@@ -1582,3 +1583,17 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (
"Output should be the same regardless of return_lse"
)


@pytest.mark.parametrize("head_dim", [4, 144, 256])
def test_flash_attn_invalid_head_dim(head_dim):
device = "cuda"
dtype = torch.bfloat16
batch_size, seqlen, nheads = 1, 64, 4

q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)

with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM100/SM110.")):
flash_attn_func(q, k, v)
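The `re.escape` call in the new test matters because the expected message contains parentheses, and `pytest.raises(match=...)` interprets its argument as a regular expression. A minimal illustration of the escaping (standalone sketch with hypothetical values, just to show why the raw message would not match):

```python
import re

msg = "(head_dim, head_dim_v)=(4, 4) is not supported on SM100/SM110."
# Unescaped, '(' and ')' act as regex groups, so the pattern no longer
# matches the literal text; re.escape turns them into '\(' and '\)'.
assert re.search(re.escape(msg), msg) is not None
assert re.search(msg, msg) is None  # unescaped pattern fails to match literally
```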