
Commit c937bc5

[CI]: Reduce test shapes to avoid OOM errors during CI.
1 parent 7211164 commit c937bc5

File tree

9 files changed (+30, -15 lines)

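Every file below follows the same pattern: hard-coded shapes inside main() become keyword defaults, and the CI tests call main() with shapes small enough to fit in GPU memory. A minimal sketch of the pattern, with illustrative names not taken from the repository:

def run_kernel(M: int, N: int, K: int) -> int:
    # Stand-in for a real kernel launch; just returns the GEMM FLOP count.
    return 2 * M * N * K

# Before: shapes fixed inside main(), so CI must run the full problem.
def main_before():
    M, N, K = 16384, 16384, 16384
    return run_kernel(M, N, K)

# After: shapes are keyword defaults; existing callers are unchanged,
# but tests can pass sizes that fit within CI memory limits.
def main_after(M=16384, N=16384, K=16384):
    return run_kernel(M, N, K)

if __name__ == "__main__":
    assert main_after(M=1024, N=1024, K=1024) == 2 * 1024**3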

examples/blocksparse_attention/test_example_blocksparse_attention.py

Lines changed: 18 additions & 2 deletions
@@ -24,11 +24,27 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
 
 
 def test_example_triton_sparse_gqa_decode_varlen_indice():
-    example_triton_sparse_gqa_decode_varlen_indice.main()
+    example_triton_sparse_gqa_decode_varlen_indice.main(
+        batch=16,
+        heads=16,
+        heads_kv=8,
+        max_cache_seqlen=4096,
+        dim=128,
+        dim_v=128,
+        sparse_ratio=0.8,
+        block_size=32)
 
 
 def test_example_triton_sparse_gqa_decode_varlen_mask():
-    example_triton_sparse_gqa_decode_varlen_mask.main()
+    example_triton_sparse_gqa_decode_varlen_mask.main(
+        batch=16,
+        heads=16,
+        heads_kv=8,
+        max_cache_seqlen=4096,
+        dim=128,
+        dim_v=128,
+        sparse_ratio=0.8,
+        block_size=32)
 
 
 if __name__ == "__main__":
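For scale, a rough estimate of the KV-cache footprint at the shapes pinned above, assuming fp16 K and V tensors of shape [batch, max_cache_seqlen, heads_kv, dim] (an assumption about the example's layout; index and mask buffers would add to this):

batch, heads_kv, max_cache_seqlen, dim = 16, 8, 4096, 128
bytes_fp16 = 2  # assumed element size
kv_bytes = 2 * batch * max_cache_seqlen * heads_kv * dim * bytes_fp16  # K plus V
print(f"KV cache: {kv_bytes / 2**20:.0f} MiB")  # 256 MiB before outputs and indices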

examples/cast/example_group_per_split_token_cast_to_fp8.py

Lines changed: 1 addition & 2 deletions
@@ -161,8 +161,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
     return x_fp8
 
 
-def main():
-    M, N, BG, blk_m = 8192, 8192, 2, 8
+def main(M=8192, N=8192, BG=2, blk_m=8):
     if dtype == "float":
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
     elif dtype == "float16":

examples/cast/example_per_token_cast_to_fp8.py

Lines changed: 1 addition & 2 deletions
@@ -79,8 +79,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return x_fp8, (x_amax / 448.0).view(m, -1)
 
 
-def main():
-    M, N, blk_m = 8192, 8192, 8
+def main(M=8192, N=8192, blk_m=8):
     kernel = per_token_cast_to_fp8(M, N, blk_m)
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

examples/cast/test_example_cast.py

Lines changed: 2 additions & 2 deletions
@@ -4,11 +4,11 @@
 
 
 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main()
+    example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
 
 
 def test_example_per_token_cast_to_fp8():
-    example_per_token_cast_to_fp8.main()
+    example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
 
 
 if __name__ == "__main__":
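Dropping N from 8192 to 2048 cuts the input tensor alone from 256 MiB to 64 MiB on the float32 path shown above (the float16 path would be half that):

def input_mib(M, N, bytes_per_elem=4):  # float32 path assumed
    return M * N * bytes_per_elem / 2**20

print(input_mib(8192, 8192))  # 256.0, the old default shape
print(input_mib(8192, 2048))  # 64.0, the reduced CI shape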

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def test_example_topk_selector():
 
 
 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer()
+    test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
 
 
 @tilelang.testing.requires_cuda

examples/dynamic_shape/example_dynamic.py

Lines changed: 1 addition & 2 deletions
@@ -96,8 +96,7 @@ def ref_program(A, B):
     print(f"Latency: {latency} ms")
 
 
-def main():
-    M, N, K = 16384, 16384, 16384
+def main(M=16384, N=16384, K=16384):
     block_M, block_N, block_K = 128, 128, 32
     trans_A, trans_B = False, False
     in_dtype, out_dtype = "float16", "float16"

examples/dynamic_shape/test_example_dynamic.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 
 def test_example_dynamic():
-    example_dynamic.main()
+    example_dynamic.main(M=1024, N=1024, K=1024)
 
 
 if __name__ == "__main__":
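The reduction here is large: the three float16 GEMM operands shrink by a factor of 256. A quick estimate matching the in_dtype/out_dtype shown above (intermediates and workspace are extra):

def operand_bytes(M, N, K, bytes_per_elem=2):  # float16 A, B, C
    return (M * K + K * N + M * N) * bytes_per_elem

print(operand_bytes(16384, 16384, 16384) / 2**30)  # 1.5 GiB at the old default
print(operand_bytes(1024, 1024, 1024) / 2**20)     # 6.0 MiB at the CI shape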

examples/flash_attention/test_example_flash_attention.py

Lines changed: 4 additions & 2 deletions
@@ -44,12 +44,14 @@ def test_example_mha_bwd_wgmma_pipelined():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_gqa_fwd_bshd_wgmma_pipelined():
-    example_gqa_fwd_bshd_wgmma_pipelined.main()
+    example_gqa_fwd_bshd_wgmma_pipelined.main(
+        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
 
 
 @tilelang.testing.requires_cuda
 def test_example_gqa_fwd_bshd():
-    example_gqa_fwd_bshd.main()
+    example_gqa_fwd_bshd.main(
+        batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
 
 
 @tilelang.testing.requires_cuda
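These two tests now pass explicit shapes and tune=False, so CI no longer depends on the examples' defaults or on autotuning. At the pinned sizes, even a naive reference that materializes the full score matrix stays small, assuming fp16 scores of shape [batch, heads, seq_len, seq_len] (flash-style kernels never materialize this):

def scores_mib(batch, heads, seq_len, bytes_per_elem=2):  # fp16 assumed
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**20

print(scores_mib(1, 16, 1024))  # 32.0 MiB at the pinned seq_len
print(scores_mib(1, 16, 4096))  # 512.0 MiB had seq_len been 4096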

testing/python/issue/test_tilelang_issue_96.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
 
 
 def test_pipeline_large_matrix():
     """Test pipeline stages with large matrix multiplication (8192x8192)"""
-    run_gemm_pipeline_test(8192)
+    run_gemm_pipeline_test(4096)
 
 
 def test_pipeline_small_matrix():
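Halving N quarters each square operand; assuming float16, that is 128 MiB versus 32 MiB per matrix. Note that the docstring above still reads "(8192x8192)" although the call now uses 4096:

for n in (8192, 4096):
    print(n, n * n * 2 / 2**20, "MiB per fp16 matrix")  # 128.0 and 32.0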
