|
1 | | -import argparse |
| 1 | +import torch |
2 | 2 | import tilelang |
3 | 3 | import tilelang.language as T |
4 | | -import torch |
5 | 4 |
|
6 | 5 |
|
7 | 6 | def ref_program(x, y): |
@@ -30,23 +29,29 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. |
30 | 29 | return elem_add |
31 | 30 |
|
32 | 31 |
|
33 | | -def main(): |
34 | | - parser = argparse.ArgumentParser() |
35 | | - parser.add_argument("--m", type=int, default=128) |
36 | | - parser.add_argument("--n", type=int, default=128) |
37 | | - args, _ = parser.parse_known_args() |
38 | | - M, N = args.m, args.n |
39 | | - |
| 32 | +def run_elementwise_add(M, N): |
40 | 33 | a = torch.randn(M, N, dtype=torch.float32, device="cuda") |
41 | 34 | b = torch.randn(M, N, dtype=torch.float32, device="cuda") |
42 | 35 |
|
43 | 36 | # Default config |
44 | | - config = {"block_M": 128, "block_N": 128, "threads": 128} |
| 37 | + block_M, block_N = 128, 128 |
| 38 | + config = {"block_M": block_M, "block_N": block_N, "threads": 128} |
45 | 39 | kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") |
46 | 40 |
|
47 | 41 | out = kernel(a, b) |
48 | 42 | torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) |
49 | | - print("All passed!") |
| 43 | + |
| 44 | + code = kernel.get_kernel_source() |
| 45 | + if block_N == N: |
| 46 | + assert "tma_load" in code and "CUtensorMap" not in code |
| 47 | + else: |
| 48 | + assert "tma_load" in code and "CUtensorMap" in code |
| 49 | + |
| 50 | + |
| 51 | +def main(): |
| 52 | + run_elementwise_add(128, 128) |
| 53 | + run_elementwise_add(256, 128) |
| 54 | + run_elementwise_add(256, 256) |
50 | 55 |
|
51 | 56 |
|
52 | 57 | if __name__ == "__main__": |
|
0 commit comments