diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py
index 517b17949..0e8e21c43 100644
--- a/examples/flash_attention/example_gqa_fwd_varlen.py
+++ b/examples/flash_attention/example_gqa_fwd_varlen.py
@@ -4,55 +4,10 @@
 import tilelang
 import tilelang.language as T
 import tilelang.testing
-from einops import rearrange, repeat
 from tilelang.profiler import do_bench
 from varlen_utils import generate_random_padding_mask, generate_qkv
 
 
-def attention_ref(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    causal=False,
-    window_size=(-1, -1),
-    upcast=True,
-):
-    if causal:
-        window_size = (window_size[0], 0)
-    dtype_og = q.dtype
-    if upcast:
-        q, k, v = q.float(), k.float(), v.float()
-    b, T, Hq, D = q.shape
-    S = k.shape[1]
-    scale = (1.0 / D) ** 0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
-    scores = torch.einsum("bthd,bshd->bhts", q, k)
-    left, right = window_size
-    left = S if left is None or left < 0 else int(left)
-    right = S if right is None or right < 0 else int(right)
-    t_idx = torch.arange(T, device=scores.device)[:, None]
-    s_idx = torch.arange(S, device=scores.device)[None, :]
-    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
-    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
-    if key_padding_mask is not None:
-        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
-        visible_mask = visible_mask & k_keep
-    neg_inf = torch.finfo(scores.dtype).min
-    scores = scores * scale
-    scores = scores.masked_fill(~visible_mask, neg_inf)
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
-    if query_padding_mask is not None:
-        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
-        attention = attention.masked_fill(~q_keep, 0.0)
-    output = torch.einsum("bhts,bshd->bthd", attention, v)
-    if query_padding_mask is not None:
-        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
-    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
-
-
 @tilelang.jit(
     out_idx=[6],
     pass_configs={
@@ -110,8 +65,10 @@ def main(
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
+            offset = kv_current_seqlen - q_current_seqlen  # always align on the right
+            max_visible_k_idx = offset + (bx + 1) * block_M
             loop_range = (
-                T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N))
                 if is_causal
                 else T.ceildiv(kv_current_seqlen, block_N)
             )
@@ -122,7 +79,7 @@ def main(
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else(
-                            (bx * block_M + i < k * block_N + j)
+                            (bx * block_M + i + offset < k * block_N + j)
                             or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
                             -1e9,
                             0,
@@ -158,9 +115,10 @@ def main(
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
             for i, j in T.Parallel(block_M, dim):
-                acc_o[i, j] /= logsum[i]
-            T.copy(acc_o, O_shared)
+                # When sq > skv, some tokens can see nothing
+                acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
 
+            T.copy(acc_o, O_shared)
             for i, d in T.Parallel(block_M, dim):
                 if bx * block_M + i < q_current_seqlen:
                     Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
@@ -218,15 +176,22 @@ def main(
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
 
-    out_ref, _ = attention_ref(
-        q,
-        k,
-        v,
-        query_padding_mask=query_padding_mask,
-        key_padding_mask=key_padding_mask,
+    import flash_attn
+
+    fa_out_unpad = flash_attn.flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
         causal=is_causal,
     )
-    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
+    fa_out = output_pad_fn(fa_out_unpad)
+    torch.testing.assert_close(out, fa_out, rtol=1e-2, atol=1e-2)
+    print("All checks passed.✅")
 
     latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
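Note on the masking change above: the varlen causal mask is now aligned to the bottom-right corner via offset = kv_current_seqlen - q_current_seqlen, so the last query token always sees the full key sequence. Below is a minimal host-side sketch of that predicate in plain PyTorch; it is not part of the patch, and the helper name is hypothetical.

import torch

def right_aligned_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Same definition as `offset` in the kernel: align queries to the right.
    offset = seqlen_k - seqlen_q
    q_idx = torch.arange(seqlen_q)[:, None]
    k_idx = torch.arange(seqlen_k)[None, :]
    # Key j is visible to query i iff i + offset >= j; the kernel masks the
    # complement, (i + offset < j), with -1e9 before the online softmax.
    return q_idx + offset >= k_idx

# seqlen_k > seqlen_q: every query row gains `offset` extra visible keys.
assert right_aligned_causal_mask(3, 5).sum(dim=-1).tolist() == [3, 4, 5]
# seqlen_q > seqlen_k: rows with i + offset < 0 see nothing, which is why the
# kernel writes 0 for them instead of dividing by their (zero) logsum.
assert right_aligned_causal_mask(5, 3).sum(dim=-1).tolist() == [0, 0, 1, 2, 3]

This is the bottom-right-aligned causal convention that flash-attn uses for unequal query/key lengths, which is what the flash_attn_varlen_func reference checks in these examples rely on.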
diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py
index 4077fe9f0..3d275348a 100644
--- a/examples/flash_attention/example_mha_fwd_varlen.py
+++ b/examples/flash_attention/example_mha_fwd_varlen.py
@@ -8,68 +8,10 @@
 from tilelang.autotuner import set_autotune_inputs, autotune
 import torch
 
-from einops import rearrange, repeat
 from varlen_utils import generate_random_padding_mask, generate_qkv
 import itertools
 
 
-def attention_ref(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    causal=False,
-    window_size=(-1, -1),  # -1 means infinite window size
-    upcast=True,
-):
-    """
-    Arguments:
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        k: (batch_size, seqlen_k, nheads_k, head_dim)
-        v: (batch_size, seqlen_k, nheads_k, head_dim)
-        query_padding_mask: (batch_size, seqlen_q)
-        key_padding_mask: (batch_size, seqlen_k)
-        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
-        dropout_p: float
-        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
-        causal: whether to apply causal masking
-        window_size: (int, int), left and right window size
-        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
-            output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
-            without changing the math. This is to estimate the numerical error from operation
-            reordering.
-    Output:
-        output: (batch_size, seqlen_q, nheads, head_dim)
-        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
-    """
-    if causal:
-        window_size = (window_size[0], 0)
-    dtype_og = q.dtype
-    if upcast:
-        q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim) ** 0.5  # log2(e)
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
-    scores = torch.einsum("bthd,bshd->bhts", q, k)
-    if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-        # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0)
-    scores = scores * scale
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
-
-    # We want to mask here so that the attention matrix doesn't have any NaNs
-    # Otherwise we'll get NaN in dV
-    if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
-    output = torch.einsum("bhts,bshd->bthd", attention, v)
-    if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
-    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
-
-
 def get_configs():
     iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2, 3], threads=[128, 256])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@@ -120,15 +62,12 @@ def main(
             head_idx = by
 
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
 
             q_end_idx = cu_seqlens_q[batch_idx + 1]
-            k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
+            kv_end_idx = cu_seqlens_k[batch_idx + 1]
 
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = kv_end_idx - kv_start_idx
 
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :], Q_shared
@@ -138,25 +77,30 @@ def main(
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            offset = kv_current_seqlen - q_current_seqlen  # always align on the right
+            loop_range = (
+                T.min(T.ceildiv(offset + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal
+                else T.ceildiv(kv_current_seqlen, block_N)
+            )
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 # Q * K
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N : k_start_idx + k * block_N + block_N, head_idx, :], K_shared
+                    K_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], K_shared
                 )  # OOB positions will be handled below
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else(
-                            (bx * block_M + i >= k * block_N + j)
-                            and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
-                            -T.infinity(acc_s.dtype),
+                            (bx * block_M + i + offset < k * block_N + j)
+                            or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
+                            -1e9,
                             0,
                         )
                 else:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else(
-                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0
+                            (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0
                         )
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -190,34 +134,34 @@ def main(
                 # V * softmax(Q * K)
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N : v_start_idx + k * block_N + block_N, head_idx, :], V_shared
+                    V_unpad[kv_start_idx + k * block_N : kv_start_idx + k * block_N + block_N, head_idx, :], V_shared
                 )  # OOB positions' weights are 0
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
             for i, j in T.Parallel(block_M, dim):
-                acc_o[i, j] /= logsum[i]
+                # When sq > skv, some tokens can see nothing
+                acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i]
+            T.copy(acc_o, O_shared)
 
-            T.copy(acc_o, O_shared)
-            T.copy(
-                O_shared, Output_unpad[q_start_idx + bx * block_M : q_start_idx + bx * block_M + block_M, head_idx, :]
-            )  # TMA will handle OOB
+            for i, d in T.Parallel(block_M, dim):
+                if bx * block_M + i < q_current_seqlen:
+                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
 
     return main
 
 
-def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, tune: bool = False):
+def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False, tune: bool = False):
     flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
     total_flops = 2 * flops_per_matmul
     tilelang.testing.set_random_seed(0)
 
-    causal = False
     if causal:
         total_flops *= 0.5
 
     dtype = torch.float16
     device = torch.device("cuda")
-    window_size = (-1, -1)
 
     q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
     k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
@@ -237,12 +181,11 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
         k,
         v,
         output_pad_fn,
-        dq_pad_fn,
-        dk_pad_fn,
+        _,
+        _,
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
 
     UQ = q_unpad.shape[0]  # unpadded query length
-    UK = k_unpad.shape[0]  # unpadded key length
     UKV = k_unpad.shape[0]  # unpadded query key length
 
     if tune:
@@ -255,16 +198,6 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
 
-    out_ref, _ = attention_ref(
-        q,
-        k,
-        v,
-        query_padding_mask,
-        key_padding_mask,
-        causal=causal,
-    )
-    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
-
     import flash_attn
 
     fla_out_unpad = flash_attn.flash_attn_varlen_func(
@@ -296,16 +229,14 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, t
     print(f"FA2: {total_flops / t * 1e-9} TFlops")
 
 
-def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
+def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128, causal: bool = False):
     flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
     total_flops = 2 * flops_per_matmul
     tilelang.testing.set_random_seed(0)
-    causal = False
     if causal:
         total_flops *= 0.5
     dtype = torch.float16
     device = torch.device("cuda")
-    window_size = (-1, -1)
     q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
     k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
     v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
@@ -327,7 +258,6 @@ def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, di
         dk_pad_fn,
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     UQ = q_unpad.shape[0]
-    UK = k_unpad.shape[0]
     UKV = k_unpad.shape[0]
 
     kernel = flashattn(batch, UQ, UKV, heads, dim, causal, block_M=128, block_N=128, num_stages=2, threads=256)
@@ -345,7 +275,8 @@ def run_kernel_only():
     parser.add_argument("--heads", type=int, default=64, help="heads")
     parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--is_causal", action="store_true", default=False, help="causal attention")
     parser.add_argument("--tune", action="store_true", default=False, help="tune the kernel")
     args = parser.parse_args()
-    main(args.batch, args.heads, args.seq_len, args.dim, args.tune)
+    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
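The causal loop_range used by both kernels can be sanity-checked on the host: with right alignment, the last row of query block bx sees keys j < offset + (bx + 1) * block_M, so only that many K blocks need to be visited. A minimal sketch under that assumption (the helper name is hypothetical, not part of the patch):

from math import ceil

def causal_loop_range(bx: int, block_M: int, block_N: int, seqlen_q: int, seqlen_k: int) -> int:
    offset = seqlen_k - seqlen_q
    max_visible_k_idx = offset + (bx + 1) * block_M  # exclusive bound on key index j
    # Clamp by the total number of K blocks, as the kernels do with T.min.
    return min(ceil(max_visible_k_idx / block_N), ceil(seqlen_k / block_N))

# Equal lengths: query block 0 of width 128 needs two 64-wide K blocks, block 1 needs four.
assert causal_loop_range(0, 128, 64, 512, 512) == 2
assert causal_loop_range(1, 128, 64, 512, 512) == 4
# seqlen_k > seqlen_q: every query block additionally covers the `offset` keys on its left.
assert causal_loop_range(0, 128, 64, 256, 512) == 6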
diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py
index da172bb62..a74bf071b 100644
--- a/examples/flash_attention/test_example_flash_attention.py
+++ b/examples/flash_attention/test_example_flash_attention.py
@@ -13,6 +13,7 @@
 import example_mha_bwd_bshd_wgmma_pipelined
 import example_mha_fwd_bhsd
 import example_gqa_bwd_tma_reduce_varlen
+import example_gqa_fwd_varlen
 
 
 @tilelang.testing.requires_cuda
@@ -94,7 +95,14 @@ def test_example_mha_fwd_bshd():
 
 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=False)
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64, causal=True)
+
+
+@tilelang.testing.requires_cuda
+def test_example_gqa_fwd_varlen():
+    example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=False)
+    example_gqa_fwd_varlen.main(batch=4, heads=16, q_seqlen=512, k_seqlen=512, dim=64, is_causal=True)
 
 
 if __name__ == "__main__":
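With attention_ref removed, both examples (and the new tests) validate against flash-attn's varlen kernel directly on the unpadded tensors. For reference, a minimal standalone sketch of that call, assuming the flash_attn package is installed; the sequence lengths here are illustrative only:

import torch
import flash_attn

torch.manual_seed(0)
heads, dim = 4, 64
seqlens = [5, 17, 32]  # per-sequence token counts, packed without padding
total = sum(seqlens)
# Prefix sums of the sequence lengths, int32 on device, as the varlen API expects.
cu_seqlens = torch.tensor([0, 5, 22, 54], dtype=torch.int32, device="cuda")
q = torch.randn(total, heads, dim, dtype=torch.float16, device="cuda")
k = torch.randn(total, heads, dim, dtype=torch.float16, device="cuda")
v = torch.randn(total, heads, dim, dtype=torch.float16, device="cuda")

# Same call shape as in the examples: dropout_p=0.0, bottom-right-aligned causal mask.
out = flash_attn.flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, max(seqlens), max(seqlens), 0.0, causal=True
)
assert out.shape == (total, heads, dim)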