3 changes: 2 additions & 1 deletion flash_attn/cute/flash_bwd.py
@@ -384,7 +384,8 @@ def __call__(
         self._check_type(*(t.element_type if t is not None else None
                            for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
+        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
+        new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s for s in t.stride[:-1]), t.stride[-1])
         mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
         self.varlen_q = (mCuSeqlensQ is not None)
         self._setup_attributes()
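For context on why the guard is needed: PyTorch's expand() broadcasts a size-1 dimension by giving it stride 0, and Tensor.stride() returns plain Python ints rather than traced MLIR values, which is what the unguarded cute.assume() call tripped over. A minimal sketch (shapes chosen to mirror the regression test added below):

    import torch

    # GQA pattern from the new test: 1 KV head broadcast to 4 query heads, no copy.
    k_orig = torch.randn(1, 2048, 1, 128, dtype=torch.bfloat16)
    k = k_orig.expand(1, 2048, 4, 128)

    print(k.stride())           # (262144, 128, 0, 1) -- stride 0 in the expanded head dim
    print(type(k.stride()[2]))  # <class 'int'> -- a plain Python int, not an MLIR value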
6 changes: 5 additions & 1 deletion flash_attn/cute/flash_bwd_sm90.py
@@ -351,8 +351,12 @@ def __call__(
         )

         # Assume all strides are divisible by 128 bits except the last stride
+        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
         new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            *(
+                cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s
+                for s in t.stride[:-1]
+            ),
             t.stride[-1],
         )
         mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
16 changes: 14 additions & 2 deletions flash_attn/cute/flash_fwd.py
@@ -659,8 +659,14 @@ def __call__(
         self._setup_attributes()
         SharedStorage = self._get_shared_storage_cls()
         # Assume all strides are divisible by 128 bits except the last stride
+        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
         new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            *(
+                cute.assume(s, divby=128 // t.element_type.width)
+                if s != 0
+                else s
+                for s in t.stride[:-1]
+            ),
             t.stride[-1],
         )
         mQ, mK, mV, mO = [
@@ -1296,8 +1302,14 @@ def __call__(
         )

         # Assume all strides are divisible by 128 bits except the last stride
+        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
         new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            *(
+                cute.assume(s, divby=128 // t.element_type.width)
+                if s != 0
+                else s
+                for s in t.stride[:-1]
+            ),
             t.stride[-1],
         )
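The same guard appears in all three kernels above. As a standalone illustration of the logic only (not the CuTe DSL call itself), here is a minimal sketch in which a hypothetical assume_divisible helper stands in for cute.assume(s, divby=...); the real call records a divisibility hint for the compiler, while the stand-in merely checks it:

    def assume_divisible(s: int, divby: int) -> int:
        # Hypothetical stand-in for cute.assume(s, divby=divby): just validate the hint.
        assert s % divby == 0, f"stride {s} is not divisible by {divby}"
        return s

    def new_stride(stride: tuple, element_bits: int) -> tuple:
        divby = 128 // element_bits  # elements per 128-bit access
        # Skip the hint for stride=0 (broadcast dims from expand()); leave the last stride alone.
        return (
            *(assume_divisible(s, divby) if s != 0 else s for s in stride[:-1]),
            stride[-1],
        )

    # bf16 (16-bit elements) K tensor expanded over the head dim, as in the test below:
    print(new_stride((262144, 128, 0, 1), element_bits=16))  # (262144, 128, 0, 1)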
114 changes: 114 additions & 0 deletions tests/cute/test_mask_mod.py
@@ -989,5 +989,119 @@ def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, to
    assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}"


def test_gqa_expand_stride_zero_bug():
    """Test that GQA with expand()-created K/V tensors works correctly.

    This is a regression test for bugs with expand()-created tensors:

    Forward bug: cute.assume() fails when tensor strides are Python int 0
    (from expand()) instead of MLIR values.
    Error: AttributeError: 'int' object has no attribute 'type'

    Backward bug: mark_layout_dynamic fails with expanded tensors.
    Error: RuntimeError: Expected strides[leading_dim] == 1, but got N.

    Trigger: expand() + transpose() creates stride=0 dimensions (GQA pattern).
    """
    torch.manual_seed(42)

    batch_size = 1
    seqlen = 2048
    headdim = 128
    n_heads = 4
    n_kv_heads = 1
    dtype = torch.bfloat16
    device = "cuda"

    q = torch.randn(batch_size, seqlen, n_heads, headdim, device=device, dtype=dtype)
    k_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype)
    v_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype)

    k = k_orig.expand(batch_size, seqlen, n_heads, headdim)
    v = v_orig.expand(batch_size, seqlen, n_heads, headdim)

    assert k.stride()[2] == 0, "K should have stride=0 in head dim from expand()"
    assert v.stride()[2] == 0, "V should have stride=0 in head dim from expand()"

    out = torch.empty_like(q)
    lse = torch.empty(batch_size, n_heads, seqlen, device=device, dtype=torch.float32)
    softmax_scale = 1.0 / math.sqrt(headdim)

    out_tuple = _flash_attn_fwd(
        q=q, k=k, v=v, out=out, lse=lse,
        softmax_scale=softmax_scale,
        causal=True,
        m_block_size=128, n_block_size=128,
        return_lse=True,
    )
    out_fwd, lse_fwd = out_tuple[0], out_tuple[1]

    assert not torch.isnan(out_fwd).any(), "Forward output contains NaN"
    assert torch.isfinite(out_fwd).all(), "Forward output contains non-finite values"

    tensors_for_ref = {"q": q, "k": k, "v": v}
    tensors_fp32 = {"q": q.float(), "k": k.float(), "v": v.float()}

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    out_ref = compute_reference_flex_attn(tensors_for_ref, causal_mask)
    out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, causal_mask)

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2
    pt_error = (out_ref - out_ref_fp32).abs().max().item()
    cute_error = (out_fwd - out_ref_fp32).abs().max().item()

    print(f"\nGQA expand stride=0 test:")
    print(f" Forward: kernel err={cute_error:.2e}, ref err={pt_error:.2e}, atol={fwd_atol:.2e}")
    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"Forward error {cute_error:.2e} exceeds {rtol}x ref error {pt_error:.2e} + {fwd_atol:.2e}"
    )

    grad_out = torch.randn_like(out_fwd)
    dq, dk, dv = _flash_attn_bwd(
        q=q, k=k, v=v, out=out_fwd, dout=grad_out, lse=lse_fwd,
        softmax_scale=softmax_scale,
        causal=True,
        m_block_size=128, n_block_size=128,
    )

    assert not torch.isnan(dq).any(), "dQ contains NaN"
    assert not torch.isnan(dk).any(), "dK contains NaN"
    assert not torch.isnan(dv).any(), "dV contains NaN"

    flex_block_mask = create_block_mask(
        causal_mask, batch_size, n_heads, seqlen, seqlen,
        device=device, BLOCK_SIZE=(128, 128),
    )
    _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32)

    bwd_rtol = 2
    bwd_atol_floor = 1e-5

    dq_atol = max(bwd_atol_floor, 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item())
    dk_atol = max(bwd_atol_floor, 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item())
    dv_atol = max(bwd_atol_floor, 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item())

    _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out)

    pt_dq_err = (dq_pt - dq_ref.to(dtype)).abs().max().item()
    pt_dk_err = (dk_pt - dk_ref.to(dtype)).abs().max().item()
    pt_dv_err = (dv_pt - dv_ref.to(dtype)).abs().max().item()

    cute_dq_err = (dq - dq_ref.to(dtype)).abs().max().item()
    cute_dk_err = (dk - dk_ref.to(dtype)).abs().max().item()
    cute_dv_err = (dv - dv_ref.to(dtype)).abs().max().item()

    print(f" Backward dQ: kernel err={cute_dq_err:.2e}, ref err={pt_dq_err:.2e}, atol={dq_atol:.2e}")
    print(f" Backward dK: kernel err={cute_dk_err:.2e}, ref err={pt_dk_err:.2e}, atol={dk_atol:.2e}")
    print(f" Backward dV: kernel err={cute_dv_err:.2e}, ref err={pt_dv_err:.2e}, atol={dv_atol:.2e}")

    assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}"
    assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}"
    assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
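The tolerance checks above follow the pattern used elsewhere in this file: the kernel's max error against an fp32 reference must stay within rtol times the low-precision reference's own error, plus an atol derived from the (x + 0.3 - 0.3 - x) rounding probe. A small sketch of that check in isolation (helper names are illustrative, not taken from the test):

    import torch

    def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> float:
        return (a - b).abs().max().item()

    def assert_kernel_close(out_kernel, out_ref_lowprec, out_ref_fp32, rtol=2.0, atol_floor=0.0):
        # atol: twice the error introduced by nudging the fp32 reference by +0.3 then -0.3,
        # i.e. an estimate of rounding noise at the output's scale.
        atol = max(atol_floor, 2 * max_abs_err(out_ref_fp32 + 0.3 - 0.3, out_ref_fp32))
        kernel_err = max_abs_err(out_kernel, out_ref_fp32)
        ref_err = max_abs_err(out_ref_lowprec, out_ref_fp32)
        assert kernel_err <= rtol * ref_err + atol, (
            f"kernel err {kernel_err:.2e} exceeds {rtol} * ref err {ref_err:.2e} + atol {atol:.2e}"
        )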