2727import onnx
2828import onnxruntime
2929import pytest
30- from onnx import ModelProto , TensorProto , helper , mapping
30+ from onnx import ModelProto , TensorProto , helper
3131
3232import tvm
3333import tvm .testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
6262def generate_random_value (shape , elem_type ) -> np .ndarray :
6363 # Extract datatype for the input.
6464 if elem_type :
65- dtype = str (onnx . mapping . TENSOR_TYPE_TO_NP_TYPE [ elem_type ] )
65+ dtype = str (helper . tensor_dtype_to_np_dtype ( elem_type ) )
6666 else :
6767 dtype = "float32"
6868
@@ -87,6 +87,7 @@ def check_correctness(
8787 opset : int = 14 ,
8888 rtol : float = 1e-7 ,
8989 atol : float = 1e-5 ,
90+ check_dtypes : bool = False ,
9091) -> None :
9192 """Run an onnx model in both onnxruntime and TVM through our importer
9293 confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
104105 atol: float
105106 Set the tolerance of correctness checking. Some ops may be show more
106107 arithmetic variance than others.
108+ check_dtypes: bool
109+ Check if data types are the same.
107110 """
108111 # Configure model format.
109112 if ir_version is not None :
@@ -152,17 +155,35 @@ def check_correctness(
152155 # while the ONNX output number is one, which is a list
153156 tvm_output = [tvm_output ]
154157
158+ def _get_numpy_subdtype (narray ):
159+ if np .issubdtype (narray .dtype , np .integer ):
160+ return "integer"
161+ elif np .issubdtype (narray .dtype , np .floating ):
162+ return "floating"
163+ elif np .issubdtype (narray .dtype , np .bool_ ):
164+ return "bool"
165+ elif np .issubdtype (narray .dtype , np .complexfloating ):
166+ return "complexfloating"
167+ else :
168+ return "other"
169+
155170 def _check_output (tvm_out , ort_out ):
156171 if isinstance (tvm_out , tuple ) and isinstance (ort_out , (tvm .runtime .ShapeTuple , list )):
157172 assert len (tvm_out ) == len (ort_out ), "Unequal number of outputs"
158173 for tvm_out_i , ort_out_i in zip (tvm_out , ort_out ):
159174 _check_output (tvm_out_i , ort_out_i )
160175 elif isinstance (tvm_out , tvm .nd .NDArray ) and isinstance (ort_out , np .ndarray ):
176+ if check_dtypes :
177+ assert tvm_out .numpy ().dtype == ort_out .dtype
161178 tvm .testing .assert_allclose (tvm_out .numpy (), ort_out , rtol = rtol , atol = atol )
162179 elif isinstance (tvm_out , tvm .runtime .ShapeTuple ) and isinstance (ort_out , np .ndarray ):
163180 shape_out = tvm .nd .array ([int (i ) for i in tvm_out ])
181+ if check_dtypes :
182+ assert _get_numpy_subdtype (shape_out .numpy ()) == _get_numpy_subdtype (ort_out )
164183 tvm .testing .assert_allclose (shape_out .numpy (), ort_out , rtol = rtol , atol = atol )
165184 elif isinstance (tvm_out , (int , float , bool )) and isinstance (ort_out , np .ndarray ):
185+ if check_dtypes :
186+ assert _get_numpy_subdtype (np .array (tvm_out )) == _get_numpy_subdtype (ort_out )
166187 tvm .testing .assert_allclose (np .array (tvm_out ), ort_out , rtol = rtol , atol = atol )
167188 else :
168189 raise ValueError (f"Unsupported types: { type (tvm_out )} , { type (ort_out )} " )
@@ -267,7 +288,7 @@ def verify_binary(
267288 )
268289
269290 model = helper .make_model (graph , producer_name = "binary_test" )
270- check_correctness (model , opset = opset )
291+ check_correctness (model , opset = opset , check_dtypes = True )
271292
272293
273294def verify_binary_scalar (op_name , attrs = {}, domain = None , dtype = TensorProto .INT32 , opset = 14 ):
@@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
282303 )
283304
284305 model = helper .make_model (graph , producer_name = "binary_test" )
285- check_correctness (model , opset = opset )
306+ check_correctness (model , opset = opset , check_dtypes = True )
286307
287308
288309def verify_compare (op_name , shape , attrs = {}, domain = None ):
@@ -1897,7 +1918,7 @@ def verify_constantofshape(input_dim, value, dtype):
18971918 ["input" ],
18981919 ["output" ],
18991920 value = helper .make_tensor (
1900- "value" , mapping . NP_TYPE_TO_TENSOR_TYPE [ np .dtype (dtype )] , (1 ,), (value ,)
1921+ "value" , helper . np_dtype_to_tensor_dtype ( np .dtype (dtype )) , (1 ,), (value ,)
19011922 ),
19021923 )
19031924
@@ -1917,7 +1938,7 @@ def verify_constantofshape(input_dim, value, dtype):
19171938 ],
19181939 outputs = [
19191940 helper .make_tensor_value_info (
1920- "output" , mapping . NP_TYPE_TO_TENSOR_TYPE [ np .dtype (dtype )] , input_dim
1941+ "output" , helper . np_dtype_to_tensor_dtype ( np .dtype (dtype )) , input_dim
19211942 )
19221943 ],
19231944 )
@@ -2299,7 +2320,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
22992320
23002321 inputs = [
23012322 helper .make_tensor_value_info (
2302- "input" , mapping . NP_TYPE_TO_TENSOR_TYPE [ indata .dtype ] , indata_shape
2323+ "input" , helper . np_dtype_to_tensor_dtype ( indata .dtype ) , indata_shape
23032324 )
23042325 ]
23052326
@@ -2333,7 +2354,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
23332354 outputs = [
23342355 helper .make_tensor_value_info (
23352356 f"output_{ i } " ,
2336- mapping . NP_TYPE_TO_TENSOR_TYPE [ indata .dtype ] ,
2357+ helper . np_dtype_to_tensor_dtype ( indata .dtype ) ,
23372358 list (outdata_shapes [i ]),
23382359 )
23392360 for i in range (len (split_index ))
0 commit comments