Skip to content

Commit 5c62d00

Browse files
authored
[Testing] Move TMA 1D and test for its functionality (#1167)
* [Testing] Move TMA 1D and test for its functionality * [Lint]
1 parent 54d4bd6 commit 5c62d00

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed
Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
import tilelang.testing
22
import example_elementwise_add
3-
import example_elementwise_add_tma_1d
43

54

65
def test_example_elementwise_add():
76
example_elementwise_add.main()
87

98

10-
def test_example_elementwise_add_tma_1d():
11-
example_elementwise_add_tma_1d.main()
12-
13-
149
if __name__ == "__main__":
1510
tilelang.testing.main()

examples/elementwise/example_elementwise_add_tma_1d.py renamed to testing/python/language/test_tilelang_language_tma_1d.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import argparse
1+
import torch
22
import tilelang
33
import tilelang.language as T
4-
import torch
54

65

76
def ref_program(x, y):
@@ -30,23 +29,29 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.
3029
return elem_add
3130

3231

33-
def main():
34-
parser = argparse.ArgumentParser()
35-
parser.add_argument("--m", type=int, default=128)
36-
parser.add_argument("--n", type=int, default=128)
37-
args, _ = parser.parse_known_args()
38-
M, N = args.m, args.n
39-
32+
def run_elementwise_add(M, N):
4033
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
4134
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
4235

4336
# Default config
44-
config = {"block_M": 128, "block_N": 128, "threads": 128}
37+
block_M, block_N = 128, 128
38+
config = {"block_M": block_M, "block_N": block_N, "threads": 128}
4539
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
4640

4741
out = kernel(a, b)
4842
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
49-
print("All passed!")
43+
44+
code = kernel.get_kernel_source()
45+
if block_N == N:
46+
assert "tma_load" in code and "CUtensorMap" not in code
47+
else:
48+
assert "tma_load" in code and "CUtensorMap" in code
49+
50+
51+
def main():
52+
run_elementwise_add(128, 128)
53+
run_elementwise_add(256, 128)
54+
run_elementwise_add(256, 256)
5055

5156

5257
if __name__ == "__main__":

0 commit comments

Comments
 (0)