diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py
index 88527f7b3..adda1f0f1 100644
--- a/examples/blocksparse_attention/test_example_blocksparse_attention.py
+++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py
@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 
 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
         batch=16,
         heads=16,
         heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py
index 4c2f574c0..102ac2021 100644
--- a/examples/cast/example_group_per_split_token_cast_to_fp8.py
+++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8
 
 
-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
 
     print("batch_sizes:", batch_sizes)
diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py
index 2f978c1d4..1ca000eb2 100644
--- a/examples/cast/test_example_cast.py
+++ b/examples/cast/test_example_cast.py
@@ -4,11 +4,12 @@
 
 
 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(
+        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
 
 
 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)
 
 
 if __name__ == "__main__":
diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
index 971a3206c..33ab00e4c 100644
--- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
+++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
@@ -13,7 +13,7 @@ def test_example_topk_selector():
 
 
 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
 
 
 @tilelang.testing.requires_cuda
@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
     test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
     test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
 
 
 if __name__ == "__main__":
diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py
index 527d89cd0..f4932aee9 100644
--- a/examples/flash_attention/test_example_flash_attention.py
+++ b/examples/flash_attention/test_example_flash_attention.py
@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():
 
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(BATCH=1)
+    example_mha_bwd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )
 
 
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd_bhsd():
-    example_mha_bwd_bhsd.main(BATCH=1)
+    example_mha_bwd_bhsd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1)
+    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
 
 
 @tilelang.testing.requires_cuda
@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():
 
 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main()
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
 
 
 if __name__ == "__main__":
diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py
index b4285a64f..3eabc9a76 100644
--- a/examples/flash_decoding/example_mha_inference.py
+++ b/examples/flash_decoding/example_mha_inference.py
@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
         3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
 
 
-def main():
-    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    causal = False
+def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
     if causal:
diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py
index a6ec1c68e..c728dfe0e 100644
--- a/examples/flash_decoding/test_example_flash_decoding.py
+++ b/examples/flash_decoding/test_example_flash_decoding.py
@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():
 
 
 def test_example_example_mha_inference():
-    example_mha_inference.main()
+    example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
 
 
 if __name__ == "__main__":