@@ -252,20 +252,25 @@ def intrin_func(ins, outs):
252252 assert ins [0 ].shape [0 ].value == n
253253 return tvm .tir .call_packed ("vadd" , ins [0 ].data , outs [0 ].data , ins [0 ].shape [0 ])
254254
255- intrin = te .decl_tensor_intrin (z .op , intrin_func )
255+ intrin = te .decl_tensor_intrin (z .op , intrin_func , default_buffer_params = { "offset_factor" : n } )
256256 assert intrin .op == z .op
257257 assert intrin .reduce_init is None
258258 assert tuple (intrin .inputs ) == tuple (z .op .input_tensors )
259259 assert intrin .buffers [0 ].shape [0 ].value == n
260260 m = 32
261- x = te .placeholder ((m ,), name = "x" )
262- y = te .placeholder ((m ,), name = "y" )
263- z = te .compute (x .shape , lambda i : x [i ] + y [i ], name = "z" )
264- s = te .create_schedule (z .op )
265- xo , xi = s [z ].split (z .op .axis [0 ], factor = n )
266- s [z ].tensorize (xi , intrin )
267- assert s [z ].iter_var_attrs [xi ].tensor_intrin == intrin
268- assert s [z ].iter_var_attrs [xi ].iter_type == tvm .te .schedule .IterVar .Tensorized
261+ X = te .placeholder ((m ,), name = "X" )
262+ Y = te .placeholder ((m ,), name = "Y" )
263+ Z = te .compute (X .shape , lambda i : X [i ] + Y [i ], name = "Z" )
264+ s = te .create_schedule (Z .op )
265+ xo , xi = s [Z ].split (Z .op .axis [0 ], factor = n )
266+ s [Z ].tensorize (xi , intrin )
267+ stmt = tvm .lower (s , [X , Y , Z ])["main" ].body
268+ assert isinstance (stmt .body , tvm .tir .Evaluate )
269+ assert str (stmt .body .value .args [0 ]) == '"vadd"'
270+ assert str (stmt .body .value .args [1 ]) == "X"
271+ assert str (stmt .body .value .args [2 ]) == "Z"
272+ assert s [Z ].iter_var_attrs [xi ].tensor_intrin == intrin
273+ assert s [Z ].iter_var_attrs [xi ].iter_type == tvm .te .schedule .IterVar .Tensorized
269274
270275
271276def test_tensor_intrin_scalar_params ():
0 commit comments