diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py
index 37e81ebb3..db16e1586 100644
--- a/examples/flash_attention/example_gqa_fwd_varlen.py
+++ b/examples/flash_attention/example_gqa_fwd_varlen.py
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
        q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
-    if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+    attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@@ -91,53 +102,53 @@ def main(
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
 
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+            })
+
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
 
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
 
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M,
                         head_idx, :], Q_shared)
-            for i, d in T.Parallel(block_M, dim):
-                if bx * block_M + i >= q_current_seqlen:
-                    Q_shared[i, d] = 0
 
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(
+                    T.ceildiv(q_current_seqlen +
+                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i,
+                              j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                  (bx * block_M + i >= q_current_seqlen or
+                                                   k * block_N + j >= kv_current_seqlen), -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                                                      k * block_N + j >= kv_current_seqlen), -1e9,
+                                                     0)
 
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -145,6 +156,9 @@ def main(
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
 
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
@@ -158,11 +172,8 @@ def main(
                     acc_o[i, j] *= scores_scale[i]
 
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
 
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@@ -191,8 +202,7 @@ def main(batch: int = 1,
 
     tilelang.testing.set_random_seed(0)
 
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
 
     tilelang.testing.set_random_seed(0)
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
     head_kv = heads // groups
 
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
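For reference, here is a minimal standalone sketch of the visibility-mask logic the updated `attention_ref` builds (a sliding window combined with the key padding mask). The helper name `make_visibility_mask` is hypothetical and only for illustration; it is not part of the example file.

```python
import torch


def make_visibility_mask(t_len, s_len, window_size=(None, None), key_padding_mask=None):
    """Boolean mask broadcastable to (batch, heads, T, S): True = visible, False = masked."""
    left, right = window_size
    # None or a negative value means "unlimited" on that side of the window.
    left = s_len if left is None or left < 0 else int(left)
    right = s_len if right is None or right < 0 else int(right)
    t_idx = torch.arange(t_len)[:, None]   # (T, 1)
    s_idx = torch.arange(s_len)[None, :]   # (1, S)
    visible = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))  # (T, S)
    visible = visible[None, None]          # (1, 1, T, S)
    if key_padding_mask is not None:       # (batch, S), True = keep
        visible = visible & key_padding_mask[:, None, None, :]
    return visible


# Causal masking falls out as the special case window_size = (None, 0).
mask = make_visibility_mask(4, 4, window_size=(None, 0))
print(mask[0, 0].int())  # lower-triangular pattern
```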