diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index 903b0cd1c51f..ad76f2d4c79c 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -50,6 +50,7 @@ import logging import json +import numpy as np from mxnet import ndarray as nd @@ -276,7 +277,7 @@ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=Fals class NodeOutput: def __init__(self, name, dtype): self.name = name - self.dtype = dtype + self.dtype = np.dtype(dtype) initializer = [] all_processed_nodes = []