diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 7b7eeced7cd..32e7743dcd4 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -68,6 +68,24 @@ def _get_device_arch():
     major, minor = torch.cuda.get_device_capability()
     return major * 10 + minor
 
+
+def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:
+    """Validate head dimension constraints based on compute capability."""
+    is_deepseek_shape = head_dim == 192 and head_dim_v == 128
+    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
+
+    if compute_capability == 9:
+        assert is_standard_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
+            f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}."
+        )
+    elif compute_capability in [10, 11]:
+        assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
+            f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek."
+        )
+
+
 
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
@@ -217,11 +235,11 @@ def _flash_attn_fwd(
             learnable_sink,
         )
     ), "inputs must be on CUDA device"
+    arch = _get_device_arch() if _arch is None else _arch
+    assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
-    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
-    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     if softcap == 0.0:
@@ -253,10 +271,6 @@ def _flash_attn_fwd(
         _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
-    arch = _get_device_arch() if _arch is None else _arch
-
-    assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
-
     use_block_sparsity = block_sparse_tensors is not None
     if mask_mod is None:
@@ -748,10 +762,8 @@ def _flash_attn_bwd(
         t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
     ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
-    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
-    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     qhead_per_kvhead = num_head // num_head_kv
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index b48964461ad..228a506196e 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -4,6 +4,7 @@
 import itertools
 import os
 import random
+import re
 
 import pytest
 import torch
@@ -1582,3 +1583,17 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (
         "Output should be the same regardless of return_lse"
     )
+
+
+@pytest.mark.parametrize("head_dim", [4, 144, 256])
+def test_flash_attn_invalid_head_dim(head_dim):
+    device = "cuda"
+    dtype = torch.bfloat16
+    batch_size, seqlen, nheads = 1, 64, 4
+
+    q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+
+    with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM100/SM110.")):
+        flash_attn_func(q, k, v)