Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Replace deprecated ConstantFill onnx operator with Identity, add comm…
Browse files Browse the repository at this point in the history
…ented out models that work with onnx export but fail with onnxruntime.
  • Loading branch information
Joe Evans committed Feb 5, 2021
1 parent f85fc62 commit c8ee99a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,10 @@ def convert_copy(node, **kwargs):

@mx_op.register("identity")
def convert_identity(node, **kwargs):
"""Map MXNet's identity operator attributes to onnx's ConstantFill operator
"""Map MXNet's identity operator attributes to onnx's Identity operator
and return the created node.
"""
return create_basic_op_node('ConstantFill', node, kwargs)
return create_basic_op_node('Identity', node, kwargs)

@mx_op.register("InstanceNorm")
def convert_instancenorm(node, **kwargs):
Expand Down Expand Up @@ -1013,12 +1013,12 @@ def convert_logistic_regression_output(node, **kwargs):
@mx_op.register("BlockGrad")
def convert_blockgrad(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)
return create_basic_op_node('Identity', node, kwargs)

@mx_op.register("MakeLoss")
def convert_makeloss(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)
return create_basic_op_node('Identity', node, kwargs)

@mx_op.register("Concat")
def convert_concat(node, **kwargs):
Expand Down
19 changes: 18 additions & 1 deletion tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,24 @@ def normalize_image(imgfile):
'center_net_resnet101_v1b_voc',
'center_net_resnet18_v1b_coco',
'center_net_resnet50_v1b_coco',
'center_net_resnet101_v1b_coco'
'center_net_resnet101_v1b_coco',
# the following models are failing due to onnxruntime errors
#'ssd_300_vgg16_atrous_voc',
#'ssd_512_vgg16_atrous_voc',
#'ssd_512_resnet50_v1_voc',
#'ssd_512_mobilenet1.0_voc',
#'faster_rcnn_resnet50_v1b_voc',
#'yolo3_darknet53_voc',
#'yolo3_mobilenet1.0_voc',
#'ssd_300_vgg16_atrous_coco',
#'ssd_512_vgg16_atrous_coco',
#'ssd_300_resnet34_v1b_coco',
#'ssd_512_resnet50_v1_coco',
#'ssd_512_mobilenet1.0_coco',
#'faster_rcnn_resnet50_v1b_coco',
#'faster_rcnn_resnet101_v1d_coco',
#'yolo3_darknet53_coco',
#'yolo3_mobilenet1.0_coco',
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
Expand Down

0 comments on commit c8ee99a

Please sign in to comment.