2727import onnx
2828import onnxruntime
2929import pytest
30- from onnx import ModelProto , TensorProto , helper , mapping
30+ from onnx import ModelProto , TensorProto , helper
3131
3232import tvm
3333import tvm .testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
6262def generate_random_value (shape , elem_type ) -> np .ndarray :
6363 # Extract datatype for the input.
6464 if elem_type :
65- dtype = str (onnx . mapping . TENSOR_TYPE_TO_NP_TYPE [ elem_type ] )
65+ dtype = str (helper . tensor_dtype_to_np_dtype ( elem_type ) )
6666 else :
6767 dtype = "float32"
6868
@@ -87,6 +87,7 @@ def check_correctness(
8787 opset : int = 14 ,
8888 rtol : float = 1e-7 ,
8989 atol : float = 1e-5 ,
90+ check_dtypes = False
9091) -> None :
9192 """Run an onnx model in both onnxruntime and TVM through our importer
9293 confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
104105 atol: float
105106 Set the tolerance of correctness checking. Some ops may be show more
106107 arithmetic variance than others.
108+ check_dtypes: bool
109+ Check if data types are the same.
107110 """
108111 # Configure model format.
109112 if ir_version is not None :
@@ -158,11 +161,15 @@ def _check_output(tvm_out, ort_out):
158161 for tvm_out_i , ort_out_i in zip (tvm_out , ort_out ):
159162 _check_output (tvm_out_i , ort_out_i )
160163 elif isinstance (tvm_out , tvm .nd .NDArray ) and isinstance (ort_out , np .ndarray ):
164+ if check_dtypes :
165+ assert _get_numpy_subdtype (tvm_out .numpy ()) == _get_numpy_subdtype (ort_out )
161166 tvm .testing .assert_allclose (tvm_out .numpy (), ort_out , rtol = rtol , atol = atol )
162167 elif isinstance (tvm_out , tvm .runtime .ShapeTuple ) and isinstance (ort_out , np .ndarray ):
163168 shape_out = tvm .nd .array ([int (i ) for i in tvm_out ])
164169 tvm .testing .assert_allclose (shape_out .numpy (), ort_out , rtol = rtol , atol = atol )
165170 elif isinstance (tvm_out , (int , float , bool )) and isinstance (ort_out , np .ndarray ):
171+ if check_dtypes :
172+ assert _get_numpy_subdtype (np .array (tvm_out )) == _get_numpy_subdtype (ort_out )
166173 tvm .testing .assert_allclose (np .array (tvm_out ), ort_out , rtol = rtol , atol = atol )
167174 else :
168175 raise ValueError (f"Unsupported types: { type (tvm_out )} , { type (ort_out )} " )
@@ -267,7 +274,7 @@ def verify_binary(
267274 )
268275
269276 model = helper .make_model (graph , producer_name = "binary_test" )
270- check_correctness (model , opset = opset )
277+ check_correctness (model , opset = opset , check_dtypes = True )
271278
272279
273280def verify_binary_scalar (op_name , attrs = {}, domain = None , dtype = TensorProto .INT32 , opset = 14 ):
@@ -282,7 +289,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
282289 )
283290
284291 model = helper .make_model (graph , producer_name = "binary_test" )
285- check_correctness (model , opset = opset )
292+ check_correctness (model , opset = opset , check_dtypes = True )
286293
287294
288295def verify_compare (op_name , shape , attrs = {}, domain = None ):
@@ -1897,7 +1904,7 @@ def verify_constantofshape(input_dim, value, dtype):
18971904 ["input" ],
18981905 ["output" ],
18991906 value = helper .make_tensor (
1900- "value" , mapping . NP_TYPE_TO_TENSOR_TYPE [ np .dtype (dtype )] , (1 ,), (value ,)
1907+ "value" , helper . np_dtype_to_tensor_dtype ( np .dtype (dtype )) , (1 ,), (value ,)
19011908 ),
19021909 )
19031910
@@ -1917,7 +1924,7 @@ def verify_constantofshape(input_dim, value, dtype):
19171924 ],
19181925 outputs = [
19191926 helper .make_tensor_value_info (
1920- "output" , mapping . NP_TYPE_TO_TENSOR_TYPE [ np .dtype (dtype )] , input_dim
1927+ "output" , helper . np_dtype_to_tensor_dtype ( np .dtype (dtype )) , input_dim
19211928 )
19221929 ],
19231930 )
@@ -2299,7 +2306,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
22992306
23002307 inputs = [
23012308 helper .make_tensor_value_info (
2302- "input" , mapping . NP_TYPE_TO_TENSOR_TYPE [ indata .dtype ] , indata_shape
2309+ "input" , helper . np_dtype_to_tensor_dtype ( indata .dtype ) , indata_shape
23032310 )
23042311 ]
23052312
@@ -2333,7 +2340,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
23332340 outputs = [
23342341 helper .make_tensor_value_info (
23352342 f"output_{ i } " ,
2336- mapping . NP_TYPE_TO_TENSOR_TYPE [ indata .dtype ] ,
2343+ helper . np_dtype_to_tensor_dtype ( indata .dtype ) ,
23372344 list (outdata_shapes [i ]),
23382345 )
23392346 for i in range (len (split_index ))
0 commit comments