@@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype():
867867 bb .normalize (relax .op .unique (x1 ))
868868
869869
870+ @pytest .mark .parametrize ("shape" , [(1 ,), (2 , 3 ), (4 , 5 , 6 )])
871+ def test_nonzero_infer_struct_info (shape ):
872+ bb = relax .BlockBuilder ()
873+ x0 = relax .Var ("x" , R .Tensor (shape , "bool" ))
874+
875+ _check_inference (
876+ bb ,
877+ relax .op .nonzero (x0 ),
878+ relax .TensorStructInfo (ndim = len (shape ) + 1 , dtype = "int64" ),
879+ )
880+
881+
882+ def test_nonzero_infer_struct_info_ndim_zero ():
883+ bb = relax .BlockBuilder ()
884+ x = relax .Var ("x" , R .Tensor ((), "bool" ))
885+
886+ _check_inference (
887+ bb ,
888+ relax .op .nonzero (x ),
889+ relax .TensorStructInfo (ndim = 2 , dtype = "int64" ),
890+ )
891+
892+
893+ def test_nonzero_infer_struct_info_wrong_input_dtype ():
894+ bb = relax .BlockBuilder ()
895+ x0 = relax .Var ("x" , relax .ShapeStructInfo ((2 , 3 , 4 )))
896+ x1 = relax .Var ("x" , relax .FuncStructInfo ([], R .Tensor ((2 , 3 , 4 ), "float32" )))
897+
898+ with pytest .raises (TVMError ):
899+ bb .normalize (relax .op .nonzero (x0 ))
900+ with pytest .raises (TVMError ):
901+ bb .normalize (relax .op .nonzero (x1 ))
902+
903+
870904if __name__ == "__main__" :
871905 tvm .testing .main ()
0 commit comments