Skip to content

Commit a5070fc

Browse files
committed
[Refactor]: Change the params in pytest to avoid OOM error during CI
1 parent 54d4bd6 commit a5070fc

File tree

6 files changed

+31
-18
lines changed

6 files changed

+31
-18
lines changed

examples/blocksparse_attention/test_example_blocksparse_attention.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
2525

2626
def test_example_triton_sparse_gqa_decode_varlen_indice():
2727
example_triton_sparse_gqa_decode_varlen_indice.main(
28-
batch=16,
29-
heads=16,
30-
heads_kv=8,
31-
max_cache_seqlen=4096,
28+
batch=8,
29+
heads=8,
30+
heads_kv=4,
31+
max_cache_seqlen=2048,
3232
dim=128,
3333
dim_v=128,
3434
sparse_ratio=0.8,
@@ -40,7 +40,7 @@ def test_example_triton_sparse_gqa_decode_varlen_mask():
4040
batch=16,
4141
heads=16,
4242
heads_kv=8,
43-
max_cache_seqlen=4096,
43+
max_cache_seqlen=1024,
4444
dim=128,
4545
dim_v=128,
4646
sparse_ratio=0.8,

examples/cast/test_example_cast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55

66
def test_example_group_per_split_token_cast_to_fp8():
7-
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
7+
example_group_per_split_token_cast_to_fp8.main(M=4196, N=1024, BG=2, blk_m=8)
88

99

1010
def test_example_per_token_cast_to_fp8():
11-
example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
11+
example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)
1212

1313

1414
if __name__ == "__main__":

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_example_topk_selector():
1313

1414

1515
def test_example_fp8_lighting_indexer():
16-
test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
16+
test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
1717

1818

1919
@tilelang.testing.requires_cuda
@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
2929
def test_example_sparse_mla_fwd_pipelined():
3030
# small shapes for testing
3131
test_sparse_mla_fwd_pipelined(
32-
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
32+
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
3333

3434

3535
@tilelang.testing.requires_cuda
3636
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
3737
def test_example_sparse_mla_bwd():
3838
test_sparse_mla_bwd(
39-
S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
39+
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
4040

4141

4242
if __name__ == "__main__":

examples/flash_attention/test_example_flash_attention.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,33 @@ def test_example_gqa_bwd_wgmma_pipelined():
3333

3434
@tilelang.testing.requires_cuda
3535
def test_example_mha_bwd():
36-
example_mha_bwd.main(BATCH=1)
36+
example_mha_bwd.main(
37+
BATCH = 1,
38+
H = 16,
39+
N_CTX = 512,
40+
D_HEAD = 64,
41+
causal = False,)
3742

3843

3944
@tilelang.testing.requires_cuda
4045
def test_example_mha_bwd_bhsd():
41-
example_mha_bwd_bhsd.main(BATCH=1)
46+
example_mha_bwd_bhsd.main(
47+
BATCH = 1,
48+
H = 16,
49+
N_CTX = 512,
50+
D_HEAD = 64,
51+
causal = False,)
4252

4353

4454
@tilelang.testing.requires_cuda
4555
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
4656
def test_example_mha_bwd_wgmma_pipelined():
47-
example_mha_bwd_wgmma_pipelined.main(BATCH=1)
57+
example_mha_bwd_wgmma_pipelined.main(
58+
BATCH = 1,
59+
H = 16,
60+
N_CTX = 512,
61+
D_HEAD = 64,
62+
causal = False,)
4863

4964

5065
@tilelang.testing.requires_cuda
@@ -84,7 +99,7 @@ def test_example_mha_fwd_bshd():
8499

85100
@tilelang.testing.requires_cuda
86101
def test_example_mha_fwd_varlen():
87-
example_mha_fwd_varlen.main()
102+
example_mha_fwd_varlen.main(batch = 4, heads = 16, seq_len = 512, dim = 64)
88103

89104

90105
if __name__ == "__main__":

examples/flash_decoding/example_mha_inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
302302
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
303303

304304

305-
def main():
306-
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
307-
causal = False
305+
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
308306
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
309307
total_flops = 2 * flops_per_matmul
310308
if causal:

examples/flash_decoding/test_example_flash_decoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():
1212

1313

1414
def test_example_example_mha_inference():
15-
example_mha_inference.main()
15+
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
1616

1717

1818
if __name__ == "__main__":

0 commit comments

Comments
 (0)