Skip to content

Commit eec7e3d

Browse files
committed
[Testing] Move the TMA 1D example into the tests and check its functionality
1 parent 10911e2 commit eec7e3d

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed
Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,10 @@
11
import tilelang.testing
22
import example_elementwise_add
3-
import example_elementwise_add_tma_1d
43

54

65
def test_example_elementwise_add():
76
example_elementwise_add.main()
87

98

10-
def test_example_elementwise_add_tma_1d():
11-
example_elementwise_add_tma_1d.main()
12-
13-
149
if __name__ == "__main__":
1510
tilelang.testing.main()

examples/elementwise/example_elementwise_add_tma_1d.py renamed to testing/python/language/test_tilelang_language_tma_1d.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import argparse
2+
import torch
23
import tilelang
34
import tilelang.language as T
4-
import torch
55

66

77
def ref_program(x, y):
@@ -30,23 +30,29 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.
3030
return elem_add
3131

3232

33-
def main():
34-
parser = argparse.ArgumentParser()
35-
parser.add_argument("--m", type=int, default=128)
36-
parser.add_argument("--n", type=int, default=128)
37-
args, _ = parser.parse_known_args()
38-
M, N = args.m, args.n
39-
33+
def run_elementwise_add(M, N):
4034
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
4135
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
4236

4337
# Default config
44-
config = {"block_M": 128, "block_N": 128, "threads": 128}
38+
block_M, block_N = 128, 128
39+
config = {"block_M": block_M, "block_N": block_N, "threads": 128}
4540
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
4641

4742
out = kernel(a, b)
4843
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
49-
print("All passed!")
44+
45+
code = kernel.get_kernel_source()
46+
if N == block_N:
47+
assert "tma_load" in code and "CUtensorMap" not in code
48+
else:
49+
assert "tma_load" in code and "CUtensorMap" in code
50+
51+
52+
def main():
53+
run_elementwise_add(128, 128)
54+
run_elementwise_add(256, 128)
55+
run_elementwise_add(256, 256)
5056

5157

5258
if __name__ == "__main__":

0 commit comments

Comments
 (0)