Skip to content

Commit d12a636

Browse files
new-TonyWangwangtongyu
andauthored
Refactor test to make it easier for user to understand how tensor_intrin works (#14017)
Signed-off-by: wangtongyu <[email protected]> Co-authored-by: wangtongyu <[email protected]>
1 parent d7253fb commit d12a636

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

tests/python/unittest/test_te_schedule.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -252,20 +252,25 @@ def intrin_func(ins, outs):
252252
assert ins[0].shape[0].value == n
253253
return tvm.tir.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])
254254

255-
intrin = te.decl_tensor_intrin(z.op, intrin_func)
255+
intrin = te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n})
256256
assert intrin.op == z.op
257257
assert intrin.reduce_init is None
258258
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
259259
assert intrin.buffers[0].shape[0].value == n
260260
m = 32
261-
x = te.placeholder((m,), name="x")
262-
y = te.placeholder((m,), name="y")
263-
z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
264-
s = te.create_schedule(z.op)
265-
xo, xi = s[z].split(z.op.axis[0], factor=n)
266-
s[z].tensorize(xi, intrin)
267-
assert s[z].iter_var_attrs[xi].tensor_intrin == intrin
268-
assert s[z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized
261+
X = te.placeholder((m,), name="X")
262+
Y = te.placeholder((m,), name="Y")
263+
Z = te.compute(X.shape, lambda i: X[i] + Y[i], name="Z")
264+
s = te.create_schedule(Z.op)
265+
xo, xi = s[Z].split(Z.op.axis[0], factor=n)
266+
s[Z].tensorize(xi, intrin)
267+
stmt = tvm.lower(s, [X, Y, Z])["main"].body
268+
assert isinstance(stmt.body, tvm.tir.Evaluate)
269+
assert str(stmt.body.value.args[0]) == '"vadd"'
270+
assert str(stmt.body.value.args[1]) == "X"
271+
assert str(stmt.body.value.args[2]) == "Z"
272+
assert s[Z].iter_var_attrs[xi].tensor_intrin == intrin
273+
assert s[Z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized
269274

270275

271276
def test_tensor_intrin_scalar_params():

0 commit comments

Comments
 (0)