Skip to content

Commit 268fad6

Browse files
committed
[Fix][ONNX] No precision widening for numpy binary operations
1 parent 60f5568 commit 268fad6

File tree

2 files changed

+38
-15
lines changed

2 files changed

+38
-15
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,17 @@ def base_impl(cls, bb, inputs, attr, params):
336336
"""Base implementation for binary operations."""
337337
if cls.numpy_op is None or cls.relax_op is None:
338338
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
339-
if all([isinstance(inp, relax.Constant) for inp in inputs]):
340-
output = cls.numpy_op( # pylint: disable=not-callable
341-
inputs[0].data.numpy(), inputs[1].data.numpy()
342-
)
343-
return relax.const(output, inputs[0].struct_info.dtype)
344-
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
339+
if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in inputs]):
345340
x = _to_numpy(inputs[0])
346341
y = _to_numpy(inputs[1])
347-
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable
342+
output = cls.numpy_op(x, y) # pylint: disable=not-callable
343+
if x.dtype == y.dtype:
344+
# no numpy precision widening
345+
output = output.astype(x.dtype)
346+
if all([isinstance(inp, relax.Constant) for inp in inputs]):
347+
return relax.const(output, output.dtype) # pylint: disable=not-callable
348+
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
349+
return relax.PrimValue(output.item()) # pylint: disable=not-callable
348350

349351
return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable
350352

tests/python/relax/test_frontend_onnx.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import onnx
2828
import onnxruntime
2929
import pytest
30-
from onnx import ModelProto, TensorProto, helper, mapping
30+
from onnx import ModelProto, TensorProto, helper
3131

3232
import tvm
3333
import tvm.testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
6262
def generate_random_value(shape, elem_type) -> np.ndarray:
6363
# Extract datatype for the input.
6464
if elem_type:
65-
dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
65+
dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
6666
else:
6767
dtype = "float32"
6868

@@ -87,6 +87,7 @@ def check_correctness(
8787
opset: int = 14,
8888
rtol: float = 1e-7,
8989
atol: float = 1e-5,
90+
check_dtypes: bool = False,
9091
) -> None:
9192
"""Run an onnx model in both onnxruntime and TVM through our importer
9293
confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
104105
atol: float
105106
Set the tolerance of correctness checking. Some ops may show more
106107
arithmetic variance than others.
108+
check_dtypes: bool
109+
Check if data types are the same.
107110
"""
108111
# Configure model format.
109112
if ir_version is not None:
@@ -152,17 +155,35 @@ def check_correctness(
152155
# while the ONNX output number is one, which is a list
153156
tvm_output = [tvm_output]
154157

158+
def _get_numpy_subdtype(narray):
159+
if np.issubdtype(narray.dtype, np.integer):
160+
return "integer"
161+
elif np.issubdtype(narray.dtype, np.floating):
162+
return "floating"
163+
elif np.issubdtype(narray.dtype, np.bool_):
164+
return "bool"
165+
elif np.issubdtype(narray.dtype, np.complexfloating):
166+
return "complexfloating"
167+
else:
168+
return "other"
169+
155170
def _check_output(tvm_out, ort_out):
156171
if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
157172
assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
158173
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
159174
_check_output(tvm_out_i, ort_out_i)
160175
elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
176+
if check_dtypes:
177+
assert tvm_out.numpy().dtype == ort_out.dtype
161178
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
162179
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
163180
shape_out = tvm.nd.array([int(i) for i in tvm_out])
181+
if check_dtypes:
182+
assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
164183
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
165184
elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
185+
if check_dtypes:
186+
assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
166187
tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
167188
else:
168189
raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
@@ -267,7 +288,7 @@ def verify_binary(
267288
)
268289

269290
model = helper.make_model(graph, producer_name="binary_test")
270-
check_correctness(model, opset=opset)
291+
check_correctness(model, opset=opset, check_dtypes=True)
271292

272293

273294
def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
@@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
282303
)
283304

284305
model = helper.make_model(graph, producer_name="binary_test")
285-
check_correctness(model, opset=opset)
306+
check_correctness(model, opset=opset, check_dtypes=True)
286307

287308

288309
def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -1897,7 +1918,7 @@ def verify_constantofshape(input_dim, value, dtype):
18971918
["input"],
18981919
["output"],
18991920
value=helper.make_tensor(
1900-
"value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,)
1921+
"value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), (1,), (value,)
19011922
),
19021923
)
19031924

@@ -1917,7 +1938,7 @@ def verify_constantofshape(input_dim, value, dtype):
19171938
],
19181939
outputs=[
19191940
helper.make_tensor_value_info(
1920-
"output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim
1941+
"output", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
19211942
)
19221943
],
19231944
)
@@ -2299,7 +2320,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
22992320

23002321
inputs = [
23012322
helper.make_tensor_value_info(
2302-
"input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype], indata_shape
2323+
"input", helper.np_dtype_to_tensor_dtype(indata.dtype), indata_shape
23032324
)
23042325
]
23052326

@@ -2333,7 +2354,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
23332354
outputs=[
23342355
helper.make_tensor_value_info(
23352356
f"output_{i}",
2336-
mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
2357+
helper.np_dtype_to_tensor_dtype(indata.dtype),
23372358
list(outdata_shapes[i]),
23382359
)
23392360
for i in range(len(split_index))

0 commit comments

Comments
 (0)