From c8ee99ace5431ca8300d1aa7e3380c7e54a7d43a Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 5 Feb 2021 17:48:59 +0000 Subject: [PATCH] Replace deprecated ConstantFill onnx operator with Identity, add commented out models that work with onnx export but fail with onnxruntime. --- .../contrib/onnx/mx2onnx/_op_translations.py | 8 ++++---- tests/python-pytest/onnx/test_onnxruntime.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3576242e0d77..a9a63950e5de 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -783,10 +783,10 @@ def convert_copy(node, **kwargs): @mx_op.register("identity") def convert_identity(node, **kwargs): - """Map MXNet's identity operator attributes to onnx's ConstantFill operator + """Map MXNet's identity operator attributes to onnx's Identity operator and return the created node. """ - return create_basic_op_node('ConstantFill', node, kwargs) + return create_basic_op_node('Identity', node, kwargs) @mx_op.register("InstanceNorm") def convert_instancenorm(node, **kwargs): @@ -1013,12 +1013,12 @@ def convert_logistic_regression_output(node, **kwargs): @mx_op.register("BlockGrad") def convert_blockgrad(node, **kwargs): """ Skip operator """ - return create_basic_op_node('ConstantFill', node, kwargs) + return create_basic_op_node('Identity', node, kwargs) @mx_op.register("MakeLoss") def convert_makeloss(node, **kwargs): """ Skip operator """ - return create_basic_op_node('ConstantFill', node, kwargs) + return create_basic_op_node('Identity', node, kwargs) @mx_op.register("Concat") def convert_concat(node, **kwargs): diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 151c5bebd9b2..319d65348bb8 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -246,7 +246,24 @@ def normalize_image(imgfile): 'center_net_resnet101_v1b_voc', 'center_net_resnet18_v1b_coco', 'center_net_resnet50_v1b_coco', - 'center_net_resnet101_v1b_coco' + 'center_net_resnet101_v1b_coco', + # the following models are failing due to onnxruntime errors + #'ssd_300_vgg16_atrous_voc', + #'ssd_512_vgg16_atrous_voc', + #'ssd_512_resnet50_v1_voc', + #'ssd_512_mobilenet1.0_voc', + #'faster_rcnn_resnet50_v1b_voc', + #'yolo3_darknet53_voc', + #'yolo3_mobilenet1.0_voc', + #'ssd_300_vgg16_atrous_coco', + #'ssd_512_vgg16_atrous_coco', + #'ssd_300_resnet34_v1b_coco', + #'ssd_512_resnet50_v1_coco', + #'ssd_512_mobilenet1.0_coco', + #'faster_rcnn_resnet50_v1b_coco', + #'faster_rcnn_resnet101_v1d_coco', + #'yolo3_darknet53_coco', + #'yolo3_mobilenet1.0_coco', ]) def test_obj_detection_model_inference_onnxruntime(tmp_path, model): def normalize_image(imgfile):