Skip to content

Commit 2ad2de4

Browse files
committed
Make TIR minimal.
1 parent a789d26 commit 2ad2de4

File tree

1 file changed

+2
-26
lines changed

1 file changed

+2
-26
lines changed

tests/python/unittest/test_tir_transform_lower_tvm_builtin.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -174,40 +174,16 @@ def build_tir():
174174

175175
def test_lower_overflow_int32():
176176
@T.prim_func
177-
def variance4(
178-
rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32"),
179-
T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float32"),
180-
):
177+
def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")):
181178
T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
182179
rxplaceholder_red = T.allocate([32], "float32", "global")
183180
T_subtract = T.allocate([822083584], "float32", "global")
184181
rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
185182
rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data)
186-
for ax1, k2 in T.grid(32, 25690112):
187-
if k2 == 0:
188-
rxplaceholder_red_1[ax1] = T.float32(0)
189-
rxplaceholder_red_1[ax1] = (
190-
rxplaceholder_red_1[ax1] + rxplaceholder_1[ax1 * 25690112 + k2]
191-
)
192-
rxplaceholder_red_2 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
193-
for ax1 in range(32):
194-
rxplaceholder_red_2[ax1] = rxplaceholder_red_1[ax1] * T.float32(3.8925482302295915e-08)
195183
T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
196184
for ax1, ax2 in T.grid(32, 25690112):
197185
cse_var_1: T.int32 = ax1 * 25690112 + ax2
198-
T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_2[ax1]
199-
T_subtract_2 = T.Buffer((T.int64(822083584),), data=T_subtract)
200-
for ax1, ax2 in T.grid(32, 25690112):
201-
cse_var_2: T.int32 = ax1 * 25690112 + ax2
202-
T_subtract_2[cse_var_2] = T_subtract_1[cse_var_2] * T_subtract_1[cse_var_2]
203-
rxplaceholder_red_3 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
204-
for ax1, k2 in T.grid(32, 25690112):
205-
if k2 == 0:
206-
rxplaceholder_red_3[ax1] = T.float32(0)
207-
rxplaceholder_red_3[ax1] = rxplaceholder_red_3[ax1] + T_subtract_2[ax1 * 25690112 + k2]
208-
for ax1 in range(32):
209-
T_divide_1 = T.Buffer((T.int64(32),), data=T_divide.data)
210-
T_divide_1[ax1] = rxplaceholder_red_3[ax1] * T.float32(3.8925482302295915e-08)
186+
T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1]
211187

212188
func = variance4
213189
tvm.build(func, target="llvm") # should not crash

0 commit comments

Comments
 (0)