@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():

 def test_example_triton_sparse_gqa_decode_varlen_indice():
     example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=4096,
+        batch=8,
+        heads=8,
+        heads_kv=4,
+        max_cache_seqlen=2048,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
         batch=16,
         heads=16,
         heads_kv=8,
-        max_cache_seqlen=4096,
+        max_cache_seqlen=1024,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
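Note: both the old and the new shapes keep heads an integer multiple of heads_kv (16/8 and 8/4, a query-group size of 2 either way), which grouped-query attention requires. A minimal sketch of that constraint, using a hypothetical helper that is not part of this diff:

def gqa_group_size(heads: int, heads_kv: int) -> int:
    # GQA shares each KV head across a contiguous group of query heads,
    # so the query-head count must divide evenly by the KV-head count.
    assert heads % heads_kv == 0, "heads must be a multiple of heads_kv"
    return heads // heads_kv

assert gqa_group_size(16, 8) == gqa_group_size(8, 4) == 2  # ratio preserved by this change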
examples/cast/example_group_per_split_token_cast_to_fp8.py (6 changes: 4 additions & 2 deletions)
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8


-def main(M=8192, N=8192, BG=2, blk_m=8):
+def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
+    if batch_sizes is None:
+        batch_sizes = [2048, 6144]
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":
@@ -170,7 +172,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
-    batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
+    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)

     print("batch_sizes:", batch_sizes)
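Note: resolving batch_sizes from None inside the function, rather than putting a list literal in the signature, avoids Python's shared-mutable-default pitfall, and the values chosen here keep the group sizes summing to M ([2048, 6144] for the default M=8192; [128, 896] for M=1024 in the test below). A minimal sketch of the pattern, with assertions inferred from these values rather than taken from the diff:

def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
    # A None default resolved here gives every call a fresh list; a literal
    # list default would be a single object shared across all calls.
    if batch_sizes is None:
        batch_sizes = [2048, 6144]
    assert len(batch_sizes) == BG, "one group size per batch group"
    assert sum(batch_sizes) == M, "group sizes partition the M rows"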
examples/cast/test_example_cast.py (5 changes: 3 additions & 2 deletions)
@@ -4,11 +4,12 @@


 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
+    example_group_per_split_token_cast_to_fp8.main(
+        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])


 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
+    example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)


 if __name__ == "__main__":
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (6 changes: 3 additions & 3 deletions)
@@ -13,7 +13,7 @@ def test_example_topk_selector():


 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
+    test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)


 @tilelang.testing.requires_cuda
@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
     test_sparse_mla_fwd_pipelined(
-        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
     test_sparse_mla_bwd(
-        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


 if __name__ == "__main__":
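Note: only the sequence lengths shrink here; the head dimensions stay fixed because they are constants of the DeepSeek MLA layout, where the QK head dim is assumed to be the V head dim plus the RoPE dim (576 = 512 + 64). The reduced SKV=512 also still accommodates topk=256. A hypothetical shape check under that assumption:

def check_mla_dims(DQK: int, DV: int, rope_dim: int = 64) -> None:
    # Assumption: DQK = DV + rope_dim in the absorbed MLA layout (576 = 512 + 64).
    assert DQK == DV + rope_dim, "QK head dim should equal V head dim plus RoPE dim"

check_mla_dims(576, 512)  # the dims left unchanged by all the tests above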
examples/flash_attention/test_example_flash_attention.py (20 changes: 16 additions & 4 deletions)
@@ -33,18 +33,30 @@ def test_example_gqa_bwd_wgmma_pipelined():

 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(BATCH=1)
+    example_mha_bwd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 def test_example_mha_bwd_bhsd():
-    example_mha_bwd_bhsd.main(BATCH=1)
+    example_mha_bwd_bhsd.main(
+        BATCH=1,
+        H=16,
+        N_CTX=512,
+        D_HEAD=64,
+        causal=False,
+    )


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1)
+    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)


 @tilelang.testing.requires_cuda
@@ -84,7 +96,7 @@ def test_example_mha_fwd_bshd():

 @tilelang.testing.requires_cuda
 def test_example_mha_fwd_varlen():
-    example_mha_fwd_varlen.main()
+    example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)


 if __name__ == "__main__":
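Note: spelling the shapes out makes the test cost explicit, and shrinking N_CTX is what dominates the saving, since self-attention FLOPs grow quadratically in sequence length but only linearly in batch and heads. A quick check of that scaling with a hypothetical helper (the causal halving is an assumption, following the usual convention):

def attn_flops(batch: int, heads: int, n_ctx: int, d_head: int, causal: bool = False) -> float:
    # Two matmuls per attention (QK^T and PV), each 2*B*H*N^2*D FLOPs.
    flops = 2 * (2.0 * batch * heads * n_ctx * n_ctx * d_head)
    return 0.5 * flops if causal else flops  # assumed: a causal mask halves the work

# N_CTX=512 vs 2048 at BATCH=1, H=16, D_HEAD=64: a 16x reduction.
print(attn_flops(1, 16, 2048, 64) / attn_flops(1, 16, 512, 64))  # 16.0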
examples/flash_decoding/example_mha_inference.py (4 changes: 1 addition & 3 deletions)
@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
         3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


-def main():
-    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    causal = False
+def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
     if causal:
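Note: at the defaults this refactor is behavior-preserving; the old hard-coded shape simply becomes the default arguments, and callers such as the test below can now shrink KV_CTX. Worked through for the new test shape using the FLOP accounting visible in the context lines (the causal branch is truncated in the diff, so the halving shown is an assumption):

BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 2048, 128
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD  # one matmul: ~2.15e9
total_flops = 2 * flops_per_matmul                            # QK^T plus PV: ~4.29e9
causal = False
if causal:
    total_flops *= 0.5  # assumed halving under a causal mask
print(f"{total_flops / 1e9:.2f} GFLOPs")  # 4.29 GFLOPs, vs 17.18 at the old KV_CTX=8192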
examples/flash_decoding/test_example_flash_decoding.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():


 def test_example_example_mha_inference():
-    example_mha_inference.main()
+    example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)


 if __name__ == "__main__":
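Note: the call above spells out every default except KV_CTX, which makes the CI shape easy to reproduce interactively. A minimal usage sketch, assuming the module is importable the same way the test file imports it:

import example_mha_inference

example_mha_inference.main()             # benchmark shape: KV_CTX defaults to 8192
example_mha_inference.main(KV_CTX=2048)  # CI shape: only the KV length differs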