examples/flash_attention/example_gqa_fwd_varlen.py (94 changes: 53 additions & 41 deletions)
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)

     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@@ -91,60 +102,63 @@ def main(
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)

             T.annotate_layout({
                 O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                 Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
             })

             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups

             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]

             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx

             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
             for i, d in T.Parallel(block_M, dim):
                 if bx * block_M + i >= q_current_seqlen:
                     Q_shared[i, d] = 0

             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))

-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(
+                    T.ceildiv(q_current_seqlen +
+                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))

Comment on lines +130 to 135
⚠️ Potential issue | 🟠 Major

Fix causal tile loop_range regression

For is_causal=True the new loop_range evaluates to ceildiv(kv_current_seqlen, block_N) for every query tile (e.g., q = kv = 2048 with block_M = block_N = 128: the very first tile now iterates all 16 K tiles instead of just the first one). We lose the triangular work reduction, so causal runs pay the full non-causal cost; the earliest tiles do up to 16x the K/V iterations they need on long sequences. Replace the sum with a clamp to the end of the current query tile (see the standalone sketch after the AI-agent prompt below).

-            loop_range = (
-                T.min(
-                    T.ceildiv(q_current_seqlen +
-                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
-                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
+            loop_range = (
+                T.min(
+                    T.ceildiv(T.min(q_current_seqlen, (bx + 1) * block_M), block_N),
+                    T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 130-135, the
causal branch builds loop_range by adding (bx+1)*block_M to q_current_seqlen
which causes every causal tile to iterate up to the full KV length; replace the
sum with a clamp to the end of the current query tile so the loop range uses the
minimum of q_current_seqlen and (bx+1)*block_M before taking ceildiv and then
min that with the KV ceildiv; implement this by computing the tile-end =
min(q_current_seqlen, (bx+1)*block_M), using ceildiv(tile-end, block_N), and
then min(...) with ceildiv(kv_current_seqlen, block_N) for the causal case
(non-causal case unchanged).
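To make the regression concrete, here is a minimal plain-Python sketch of the tile-count arithmetic. It is not a TileLang kernel: ceildiv is a local stand-in for T.ceildiv, and q_len / kv_len are hypothetical values mirroring q_current_seqlen / kv_current_seqlen for the q = kv = 2048, block_M = block_N = 128 shape used in the comment above.

# Plain-Python sketch of the causal loop_range arithmetic; bx indexes the
# query tile. Shapes are the hypothetical ones from the review example.
def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b

q_len = kv_len = 2048
block_M = block_N = 128

merged_total = clamped_total = 0
for bx in range(ceildiv(q_len, block_M)):
    # Merged expression from this PR: q_len + (bx + 1) * block_M always
    # exceeds kv_len here, so min() clamps to the full KV tile count of 16.
    merged = min(ceildiv(q_len + (bx + 1) * block_M, block_N),
                 ceildiv(kv_len, block_N))
    # Suggested fix: clamp to the end of the current query tile first.
    clamped = min(ceildiv(min(q_len, (bx + 1) * block_M), block_N),
                  ceildiv(kv_len, block_N))
    merged_total += merged
    clamped_total += clamped

print(merged_total, clamped_total)  # 256 vs. 136

With the merged expression the first query tile scans all 16 KV tiles where one suffices, and the inner loop runs 256 tile iterations in total instead of the 136 the triangular schedule needs, which is the lost reduction described above.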

             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0

                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i,
+                              j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                  (bx * block_M + i >= q_current_seqlen or
+                                                   k * block_N + j >= kv_current_seqlen), -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                                                      k * block_N + j >= kv_current_seqlen), -1e9,
+                                                     0)

                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)

                 for i in T.Parallel(block_M):
                     scores_max[i] = T.max(scores_max[i], scores_max_prev[i])

                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
@@ -158,11 +172,8 @@ def main(
                     acc_o[i, j] *= scores_scale[i]

                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0

                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

@@ -191,8 +202,7 @@ def main(batch: int = 1,

     tilelang.testing.set_random_seed(0)

-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5

     tilelang.testing.set_random_seed(0)
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")

     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)

     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
Expand Down Expand Up @@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)

     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
