Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 21 additions & 56 deletions examples/flash_attention/example_gqa_fwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,10 @@
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv


def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1),
upcast=True,
):
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
b, T, Hq, D = q.shape
S = k.shape[1]
scale = (1.0 / D) ** 0.5
k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
left, right = window_size
left = S if left is None or left < 0 else int(left)
right = S if right is None or right < 0 else int(right)
t_idx = torch.arange(T, device=scores.device)[:, None]
s_idx = torch.arange(S, device=scores.device)[None, :]
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
if key_padding_mask is not None:
k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
visible_mask = visible_mask & k_keep
neg_inf = torch.finfo(scores.dtype).min
scores = scores * scale
scores = scores.masked_fill(~visible_mask, neg_inf)
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if query_padding_mask is not None:
q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
attention = attention.masked_fill(~q_keep, 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@tilelang.jit(
out_idx=[6],
pass_configs={
Expand Down Expand Up @@ -110,8 +65,10 @@ def main(
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

offset = kv_current_seqlen - q_current_seqlen # always align on the right
max_visible_k_idx = offset + (bx + 1) * block_M
loop_range = (
T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal
else T.ceildiv(kv_current_seqlen, block_N)
)
Comment on lines +68 to 74
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Verify loop_range behavior when max_visible_k_idx is negative.

When q_current_seqlen > kv_current_seqlen (offset < 0) and bx is small, max_visible_k_idx can be negative or very small. The behavior of T.ceildiv(max_visible_k_idx, block_N) with a negative numerator may be implementation-dependent and could lead to:

  1. loop_range = 0, causing the loop at line 76 not to execute
  2. logsum[i] remaining 0 for some positions
  3. Potential division by zero at line 119 if the guard doesn't catch all cases

Consider adding an explicit guard to ensure loop_range >= 0 and that positions with loop_range = 0 are handled correctly:

🔎 Suggested fix to add explicit bounds
 offset = kv_current_seqlen - q_current_seqlen  # always align on the right
 max_visible_k_idx = offset + (bx + 1) * block_M
 loop_range = (
-    T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
+    T.max(0, T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)))
     if is_causal
     else T.ceildiv(kv_current_seqlen, block_N)
 )

Expand All @@ -122,7 +79,7 @@ def main(
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i < k * block_N + j)
(bx * block_M + i + offset < k * block_N + j)
or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
-1e9,
0,
Expand Down Expand Up @@ -158,9 +115,10 @@ def main(
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
# When sq > skv, some tokens can see nothing
acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
Comment on lines 117 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Add guard for zero logsum to prevent division by zero.

While the condition bx * block_M + i + offset < 0 correctly handles tokens before the KV sequence start, there may be edge cases where offset >= 0 but the loop didn't execute (e.g., due to boundary conditions), leaving logsum[i] = 0. This would cause division by zero.

🔎 Proposed fix to add logsum guard
 for i, j in T.Parallel(block_M, dim):
     # When sq > skv, some tokens can see nothing
-    acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
+    acc_o[i, j] = 0 if (is_causal and bx * block_M + i + offset < 0) or logsum[i] == 0 else acc_o[i, j] / logsum[i]
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 117-119, the
final assignment divides acc_o[i, j] by logsum[i] without guarding against
logsum being zero; update the conditional so that if logsum[i] is zero (or below
a tiny epsilon) you set acc_o[i, j] = 0 (or the same branch as tokens that see
nothing), otherwise perform the division; implement the check in the same
if/else expression to avoid division-by-zero and consider using a small epsilon
for numerical safety.


T.copy(acc_o, O_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
Expand Down Expand Up @@ -218,15 +176,22 @@ def main(
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)

out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask=query_padding_mask,
key_padding_mask=key_padding_mask,
import flash_attn

fa_out_unpad = flash_attn.flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=is_causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
fa_out = output_pad_fn(fa_out_unpad)
torch.testing.assert_close(out, fa_out, rtol=1e-2, atol=1e-2)

print("All checks passed.✅")
latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency))
Expand Down
123 changes: 27 additions & 96 deletions examples/flash_attention/example_mha_fwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,68 +8,10 @@
from tilelang.autotuner import set_autotune_inputs, autotune

import torch
from einops import rearrange, repeat
from varlen_utils import generate_random_padding_mask, generate_qkv
import itertools


def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
scale = (1.0 / dim) ** 0.5 # log2(e)
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0)
scores = scores * scale
attention = torch.softmax(scores, dim=-1).to(v.dtype)

# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def get_configs():
iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2, 3], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
Expand Down Expand Up @@ -120,15 +62,12 @@ def main(
head_idx = by

q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
kv_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
kv_end_idx = cu_seqlens_k[batch_idx + 1]

q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
v_current_seqlen = v_end_idx - v_start_idx
kv_current_seqlen = kv_end_idx - kv_start_idx

T.copy(
Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :], Q_shared
Expand All @@ -138,25 +77,30 @@ def main(
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

loop_range = T.ceildiv(k_current_seqlen, block_N)
offset = kv_current_seqlen - q_current_seqlen # always align on the right
loop_range = (
T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal
else T.ceildiv(kv_current_seqlen, block_N)
)

for k in T.Pipelined(loop_range, num_stages=num_stages):
# Q * K
T.copy(
K_unpad[k_start_idx + k * block_N : k_start_idx + k * block_N + block_N, head_idx, :], K_shared
K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], K_shared
) # OOB positions will be handled below
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= k * block_N + j)
and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype),
(bx * block_M + i + offset < k * block_N + j)
or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
-1e9,
0,
)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0
(bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0
)

T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Expand Down Expand Up @@ -190,34 +134,34 @@ def main(

# V * softmax(Q * K)
T.copy(
V_unpad[v_start_idx + k * block_N : v_start_idx + k * block_N + block_N, head_idx, :], V_shared
V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], V_shared
) # OOB positions' weights are 0

T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
# When sq > skv, some tokens can see nothing
acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]

T.copy(acc_o, O_shared)
T.copy(
O_shared, Output_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :]
) # TMA will handle OOB
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

return main


def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, tune: bool = False):
def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul

tilelang.testing.set_random_seed(0)

causal = False
if causal:
total_flops *= 0.5

dtype = torch.float16
device = torch.device("cuda")
window_size = (-1, -1)

q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
Expand All @@ -237,12 +181,11 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
_,
_,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

UQ = q_unpad.shape[0] # unpadded query length
UK = k_unpad.shape[0] # unpadded key length
UKV = k_unpad.shape[0]  # unpadded key/value length (k and v share cu_seqlens_k)

if tune:
Expand All @@ -255,16 +198,6 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)

out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)

import flash_attn

fla_out_unpad = flash_attn.flash_attn_varlen_func(
Expand Down Expand Up @@ -296,16 +229,14 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
print(f"FA2: {total_flops / t * 1e-9} TFlops")


def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
tilelang.testing.set_random_seed(0)
causal = False
if causal:
total_flops *= 0.5
dtype = torch.float16
device = torch.device("cuda")
window_size = (-1, -1)
q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
Expand All @@ -327,7 +258,6 @@ def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, di
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0]
UK = k_unpad.shape[0]
UKV = k_unpad.shape[0]
kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=128, block_N=128, num_stages=2, threads=256)

Expand All @@ -345,7 +275,8 @@ def run_kernel_only():
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", default=False, help="causal attention")
parser.add_argument("--tune", action="store_true", default=False, help="tune the kernel")

args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.tune)
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
10 changes: 9 additions & 1 deletion examples/flash_attention/test_example_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
import example_gqa_fwd_varlen


@tilelang.testing.requires_cuda
Expand Down Expand Up @@ -94,7 +95,14 @@ def test_example_mha_fwd_bshd():

@tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=False)
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=True)


@tilelang.testing.requires_cuda
def test_example_gqa_fwd_varlen():
example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=False)
example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=True)


if __name__ == "__main__":
Expand Down
Loading