
Commit 9899be1

[Refactor] Simplify tensor_null_test function and remove ptr_null_test
This commit refactors the tensor_null_test function by moving with_bias from a runtime kernel argument to a compile-time Python parameter, and removes the pointer-based ptr_null_test function. The run_test function is updated to match, streamlining the null-pointer tests for tensor operations.
1 parent e2cbf27 commit 9899be1
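
In effect, with_bias moves from a runtime T.bool argument of the generated kernel to a Python-level parameter of the jitted function, so the flag is fixed when the kernel is built rather than passed on every launch. A minimal before/after sketch of the call site, drawn from the diff below (the note about per-flag specialization is an assumption about how tl.jit treats Python parameters):

    # Before: the flag travels with every kernel launch as a T.bool argument.
    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    kernel(a, b, c, None, False)   # no bias: Bias slot passed as None (null pointer)

    # After: the flag is baked in at build time; each value yields its own kernel.
    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
    kernel(a, b, c, None)          # Bias slot still accepts None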

File tree

1 file changed: 3 additions, 63 deletions


testing/python/jit/test_tilelang_jit_nullptr.py

Lines changed: 3 additions & 63 deletions
@@ -7,57 +7,14 @@
 
 
 @tl.jit
-def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-
-    @T.prim_func
-    def main(
-            a_ptr: T.ptr,
-            b_ptr: T.ptr,
-            c_ptr: T.ptr,
-            bias_ptr: T.ptr,
-            m: T.int32,
-            n: T.int32,
-            k: T.int32,
-            with_bias: T.bool,
-    ):
-        A = T.make_tensor(a_ptr, (m, k), dtype)
-        B = T.make_tensor(b_ptr, (k, n), dtype)
-        C = T.make_tensor(c_ptr, (m, n), accum_dtype)
-        Bias = T.make_tensor(bias_ptr, (n), accum_dtype)
-
-        # Initialize Kernel Context
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_N, block_K), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            T.clear(C_local)
-
-            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
-                # Copy tile of A
-                T.copy(A[by * block_M, ko * block_K], A_shared)
-                T.copy(B[bx * block_N, ko * block_K], B_shared)
-                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
-
-            if with_bias:
-                for i, j in T.Parallel(block_M, block_N):
-                    C_local[i, j] += Bias[bx * block_N + j]
-
-            T.copy(C_local, C[by * block_M, bx * block_N])
-
-    return main
-
-
-@tl.jit
-def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
 
     @T.prim_func
     def main(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), accum_dtype),
             Bias: T.Tensor((N), accum_dtype),
-            with_bias: T.bool,
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -83,29 +40,12 @@ def main(
 
 
 def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-    kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
 
     a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
     b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
     c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
-    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))
-    kernel(a, b, c, None, M, N, K, False)
-
-    ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
-    ref_with_bias = ref_no_bias + d
-
-    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
-
-    kernel(a, b, c, d, M, N, K, True)
-
-    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
-
-    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
-    kernel(a, b, c, None, False)
-    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
-    kernel(a, b, c, d, True)
-    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
-
+    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
+    kernel(a, b, c, None)
 
 def test_nullptr():
     run_test(1024, 1024, 1024, 128, 128, 32)
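
Note that the updated run_test only exercises the with_bias=False path and no longer compares against a reference result. A sketch of how the biased path could be checked under the new signature, reusing the reference computation from the removed code (a hypothetical extension, not part of this commit; it assumes a with_bias=True build accepts a real tensor in the Bias slot):

    # Hypothetical follow-up check (not in this commit):
    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))
    kernel_bias = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=True)
    kernel_bias(a, b, c, d)
    ref_with_bias = (a @ b.T).to(map_torch_type(accum_dtype)) + d
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)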
