Skip to content

Commit 605c745

Browse files
committed
[Fix][ONNX] No precision widening for numpy binary operations
1 parent 60f5568 commit 605c745

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,17 @@ def base_impl(cls, bb, inputs, attr, params):
336336
"""Base implementation for binary operations."""
337337
if cls.numpy_op is None or cls.relax_op is None:
338338
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
339-
if all([isinstance(inp, relax.Constant) for inp in inputs]):
340-
output = cls.numpy_op( # pylint: disable=not-callable
341-
inputs[0].data.numpy(), inputs[1].data.numpy()
342-
)
343-
return relax.const(output, inputs[0].struct_info.dtype)
344-
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
339+
if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in inputs]):
345340
x = _to_numpy(inputs[0])
346341
y = _to_numpy(inputs[1])
347-
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable
342+
output = cls.numpy_op(x, y) # pylint: disable=not-callable
343+
if x.dtype == y.dtype:
344+
# no numpy precision widening
345+
output = output.astype(x.dtype)
346+
if all([isinstance(inp, relax.Constant) for inp in inputs]):
347+
return relax.const(output, output.dtype) # pylint: disable=not-callable
348+
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
349+
return relax.PrimValue(output.item()) # pylint: disable=not-callable
348350

349351
return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable
350352

tests/python/relax/test_frontend_onnx.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import onnx
2828
import onnxruntime
2929
import pytest
30-
from onnx import ModelProto, TensorProto, helper, mapping
30+
from onnx import ModelProto, TensorProto, helper
3131

3232
import tvm
3333
import tvm.testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
6262
def generate_random_value(shape, elem_type) -> np.ndarray:
6363
# Extract datatype for the input.
6464
if elem_type:
65-
dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
65+
dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
6666
else:
6767
dtype = "float32"
6868

@@ -87,6 +87,7 @@ def check_correctness(
8787
opset: int = 14,
8888
rtol: float = 1e-7,
8989
atol: float = 1e-5,
90+
check_dtypes = False
9091
) -> None:
9192
"""Run an onnx model in both onnxruntime and TVM through our importer
9293
confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
104105
atol: float
105106
Set the tolerance of correctness checking. Some ops may be show more
106107
arithmetic variance than others.
108+
check_dtypes: bool
109+
Check if data types are the same.
107110
"""
108111
# Configure model format.
109112
if ir_version is not None:
@@ -158,11 +161,15 @@ def _check_output(tvm_out, ort_out):
158161
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
159162
_check_output(tvm_out_i, ort_out_i)
160163
elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
164+
if check_dtypes:
165+
assert _get_numpy_subdtype(tvm_out.numpy()) == _get_numpy_subdtype(ort_out)
161166
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
162167
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
163168
shape_out = tvm.nd.array([int(i) for i in tvm_out])
164169
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
165170
elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
171+
if check_dtypes:
172+
assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
166173
tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
167174
else:
168175
raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
@@ -267,7 +274,7 @@ def verify_binary(
267274
)
268275

269276
model = helper.make_model(graph, producer_name="binary_test")
270-
check_correctness(model, opset=opset)
277+
check_correctness(model, opset=opset, check_dtypes=True)
271278

272279

273280
def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
@@ -282,7 +289,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
282289
)
283290

284291
model = helper.make_model(graph, producer_name="binary_test")
285-
check_correctness(model, opset=opset)
292+
check_correctness(model, opset=opset, check_dtypes=True)
286293

287294

288295
def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -1897,7 +1904,7 @@ def verify_constantofshape(input_dim, value, dtype):
18971904
["input"],
18981905
["output"],
18991906
value=helper.make_tensor(
1900-
"value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,)
1907+
"value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), (1,), (value,)
19011908
),
19021909
)
19031910

@@ -1917,7 +1924,7 @@ def verify_constantofshape(input_dim, value, dtype):
19171924
],
19181925
outputs=[
19191926
helper.make_tensor_value_info(
1920-
"output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim
1927+
"output", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
19211928
)
19221929
],
19231930
)
@@ -2299,7 +2306,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
22992306

23002307
inputs = [
23012308
helper.make_tensor_value_info(
2302-
"input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype], indata_shape
2309+
"input", helper.np_dtype_to_tensor_dtype(indata.dtype), indata_shape
23032310
)
23042311
]
23052312

@@ -2333,7 +2340,7 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
23332340
outputs=[
23342341
helper.make_tensor_value_info(
23352342
f"output_{i}",
2336-
mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
2343+
helper.np_dtype_to_tensor_dtype(indata.dtype),
23372344
list(outdata_shapes[i]),
23382345
)
23392346
for i in range(len(split_index))

0 commit comments

Comments
 (0)