diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index ef3d8baed..bd43db515 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -5,7 +5,6 @@ import argparse import tilelang import tilelang.language as T -from tilelang.autotuner import autotune torch.manual_seed(0) tilelang.disable_cache() @@ -198,7 +197,7 @@ def get_configs(): return configs -@autotune(configs=get_configs(), warmup=10, rep=10) +# @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") def flashattn( batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 @@ -438,127 +437,6 @@ def grid(META): return O, S -def test_equal_seqlen_decode_main(args): - """Test decode kernel with equal sequence lengths""" - print("Testing decode kernel with equal sequence lengths") - - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - k_seqlen = args.k_seqlen - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - - # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) - softmax_scale = 1.0 / math.sqrt(head_size) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Convert to varlen format for K, V - k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) - v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) - - # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) - max_seqlen_k = k_seqlen - - print(f"q shape: {q.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) - - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - ) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - ) - for i in range(batch_size): - S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 - - # Compute torch reference - q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] - k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] - v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] - - if sink is None: - # Standard 
scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - attn_weights = torch.softmax(logits, dim=-1) - O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] - - # Compute attention score pooling - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, k_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True, - ).to(torch.float16) - - print("S_tilelang", S_tilelang) - print("attn_score_pooled", attn_score_pooled) - - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) - max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) - - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" - print("✅ All tests passed!") - - def test_varlen_decode_main(args): """Test decode kernel with variable sequence lengths""" batch_size = args.batch_size @@ -742,16 +620,23 @@ def test_varlen_decode_main(args): print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + max_diff_s_tl = torch.max( + torch.abs( + S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)] + ) + ) print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( - f"Score mismatch: {max_diff_s_tl.item()}" - ) + assert torch.allclose( + S_tilelang[:, 
:, : math.ceil(max_seqlen_k / block_size)], + attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)], + atol=1e-2, + rtol=1e-2, + ), f"Score mismatch: {max_diff_s_tl.item()}" print("✅ All tests passed!") @@ -882,6 +767,23 @@ def speed_benchmark_decode_comparison(args): print(f"Speedup: {(triton_time / tilelang_time):.3f}") +def main(): + args = argparse.Namespace( + batch_size=1, + q_heads=32, + kv_heads=8, + k_seqlen=8192, + head_size=128, + block_size=128, + dtype=T.float16, + ) + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + test_varlen_decode_main(args) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") @@ -889,7 +791,7 @@ def speed_benchmark_decode_comparison(args): parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") - parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") @@ -897,13 +799,11 @@ def speed_benchmark_decode_comparison(args): parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") args = parser.parse_args() args.test_sink = True - args.test_varlen = False + args.test_varlen = True args.dtype = T.float16 args.num_split = 1 if args.benchmark: speed_benchmark_decode_comparison(args) - elif args.test_varlen: - test_varlen_decode_main(args) else: - test_equal_seqlen_decode_main(args) + test_varlen_decode_main(args) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py index 0984e7075..87a828b50 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -199,138 +199,6 @@ def flash_attn_with_attn_pool_decode_tilelang( return O_tl, S_tl -def test_equal_seqlen_decode_main(args): - """Test decode kernel with equal sequence lengths""" - print("Testing decode kernel with equal sequence lengths") - - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - k_seqlen = args.k_seqlen - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size - page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - - # For decode, query is just 1 token per batch - q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) - k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) - v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) - softmax_scale = 1.0 / math.sqrt(head_size) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", 
dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Convert to varlen format for K, V - k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() - v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() - - # Generate cumulative sequence lengths - cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) - max_seqlen_k = k_seqlen - - print(f"q shape: {q.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) - - block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) - block_cnt = 0 - for i in range(batch): - cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() - for j in range(math.ceil(cur_seqlen / page_block_size)): - block_table[i, j] = block_cnt - block_cnt += 1 - block_cnt = 0 - - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - ) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - block_table=block_table, - ) - for i in range(batch_size): - S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 - - # Compute torch reference - q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] - k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] - v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] - - if sink is None: - # Standard scaled dot-product attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - attn_weights = torch.softmax(logits, dim=-1) - O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] - - # Compute attention score pooling - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, k_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True, - ).to(torch.float16) - - print("S_tilelang", S_tilelang) - print("attn_score_pooled", attn_score_pooled) - - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_s = torch.max(torch.abs(S_triton - 
attn_score_pooled)) - max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) - max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) - - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") - assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" - assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" - print("✅ All tests passed!") - - def test_varlen_decode_main(args): """Test decode kernel with variable sequence lengths""" batch_size = args.batch_size @@ -651,6 +519,24 @@ def speed_benchmark_decode_comparison(args): print(f"Speedup: {(triton_time / tilelang_time):.3f}") +def main(): + args = argparse.Namespace( + batch_size=1, + q_heads=32, + kv_heads=8, + k_seqlen=8192, + head_size=128, + block_size=128, + dtype=T.float16, + ) + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + args.page_block_size = 128 + test_varlen_decode_main(args) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") @@ -673,7 +559,5 @@ def speed_benchmark_decode_comparison(args): if args.benchmark: speed_benchmark_decode_comparison(args) - elif args.test_varlen: - test_varlen_decode_main(args) else: - test_equal_seqlen_decode_main(args) + test_varlen_decode_main(args) diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py index c728dfe0e..a02a92097 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -2,6 +2,8 @@ import example_gqa_decode import example_mha_inference +import example_gqa_decode_varlen_logits +import example_gqa_decode_varlen_logits_paged # TODO(lei): fix the correctness of gqa decode on sm90 @@ -15,5 +17,13 @@ def test_example_example_mha_inference(): example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) +def test_example_example_gqa_decode_varlen_logits(): + example_gqa_decode_varlen_logits.main() + + +def test_example_example_gqa_decode_varlen_logits_paged(): + example_gqa_decode_varlen_logits_paged.main() + + if __name__ == "__main__": tilelang.testing.main()
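Usage: beyond the pytest entries added above, the new main() entry points can be smoke-tested directly. A minimal sketch, assuming the examples/flash_decoding directory is on sys.path and a CUDA device is available (both modules require a GPU):

    # Run the same checks the new pytest entries perform, without pytest.
    import example_gqa_decode_varlen_logits as varlen_logits
    import example_gqa_decode_varlen_logits_paged as varlen_logits_paged

    # Variable-length decode with sink attention, num_split=1,
    # using the fixed Namespace defaults baked into main().
    varlen_logits.main()

    # Same varlen test against the paged-KV kernel; main() additionally
    # sets page_block_size=128 before calling test_varlen_decode_main.
    varlen_logits_paged.main()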