diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py index 4e18a50f8247..544f791293d4 100644 --- a/example/onnx/super_resolution.py +++ b/example/onnx/super_resolution.py @@ -34,7 +34,7 @@ def import_onnx(): """Import the onnx model into mxnet""" model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx' - download(model_url, 'super_resolution.onnx', version_tag = '"7348c879d16c42bc77e24e270f663524"') + download(model_url, 'super_resolution.onnx', version_tag='"7348c879d16c42bc77e24e270f663524"') LOGGER.info("Converting onnx format to mxnet's symbol and params...") sym, params = onnx_mxnet.import_model('super_resolution.onnx') @@ -46,14 +46,14 @@ def get_test_image(): # Load test image input_image_dim = 224 img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg' - download(img_url, 'super_res_input.jpg', version_tag = '"02c90a7248e51316b11f7f39dd1b226d"') + download(img_url, 'super_res_input.jpg', version_tag='"02c90a7248e51316b11f7f39dd1b226d"') img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim)) img_ycbcr = img.convert("YCbCr") img_y, img_cb, img_cr = img_ycbcr.split() input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :] return input_image, img_cb, img_cr -def perform_inference((sym, params), (input_img, img_cb, img_cr)): +def perform_inference(sym, params, input_img, img_cb, img_cr): """Perform inference on image using mxnet""" # create module mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None) @@ -79,4 +79,6 @@ def perform_inference((sym, params), (input_img, img_cb, img_cr)): return result_img if __name__ == '__main__': - perform_inference(import_onnx(), get_test_image()) + MX_SYM, MX_PARAM = import_onnx() + INPUT_IMG, IMG_CB, IMG_CR = get_test_image() + perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR) diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py index ea1058090150..2693a19e97fe 100644 --- a/tests/python-pytest/onnx/onnx_test.py +++ b/tests/python-pytest/onnx/onnx_test.py @@ -110,9 +110,6 @@ def test_super_resolution(): assert len(sym.list_outputs()) == 1 assert sym.list_outputs()[0] == 'reshape5_output' - assert len(sym.list_attr()) == 1 - assert sym.list_attr()['shape'] == '(1L, 1L, 672L, 672L)' - attrs_keys = sym.attr_dict().keys() assert len(attrs_keys) == 19 for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7', @@ -121,18 +118,20 @@ def test_super_resolution(): 'reshape1', 'convolution2', 'convolution3', 'convolution0', 'convolution1', 'reshape5', 'transpose0']): - assert attrs_keys[i] == key_item + assert key_item in attrs_keys param_keys = params.keys() assert len(param_keys) == 8 for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6', 'param_1', 'param_0', 'param_3', 'param_2']): - assert param_keys[i] == param_item + assert param_item in param_keys + LOGGER.info("Asserted the result of the onnx model conversion") output_img_dim = 672 - result_img = super_resolution.perform_inference((sym, params), - super_resolution.get_test_image()) + input_image, img_cb, img_cr = super_resolution.get_test_image() + result_img = super_resolution.perform_inference(sym, params, input_image, + img_cb, img_cr) assert result_img.size == (output_img_dim, output_img_dim) if __name__ == '__main__':