Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix method arguments for Python3.5+
Browse files Browse the repository at this point in the history
Signed-off-by: Acharya <[email protected]>
  • Loading branch information
Acharya committed Mar 13, 2018
1 parent 1d02490 commit c04881d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
12 changes: 7 additions & 5 deletions example/onnx/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
def import_onnx():
"""Import the onnx model into mxnet"""
model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
download(model_url, 'super_resolution.onnx', version_tag = '"7348c879d16c42bc77e24e270f663524"')
download(model_url, 'super_resolution.onnx', version_tag='"7348c879d16c42bc77e24e270f663524"')

LOGGER.info("Converting onnx format to mxnet's symbol and params...")
sym, params = onnx_mxnet.import_model('super_resolution.onnx')
Expand All @@ -46,15 +46,15 @@ def get_test_image():
# Load test image
input_image_dim = 224
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
download(img_url, 'super_res_input.jpg', version_tag = '"02c90a7248e51316b11f7f39dd1b226d"')
download(img_url, 'super_res_input.jpg', version_tag='"02c90a7248e51316b11f7f39dd1b226d"')
img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
img_ycbcr = img.convert("YCbCr")
img_y, img_cb, img_cr = img_ycbcr.split()
input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
return input_image, img_cb, img_cr

def perform_inference((sym, params), (input_img, img_cb, img_cr)):
"""Perform inference on image using mxnet"""
def perform_inference(sym, params, input_img, img_cb, img_cr):
"""Perform inference on image using mxnet."""
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
Expand All @@ -79,4 +79,6 @@ def perform_inference((sym, params), (input_img, img_cb, img_cr)):
return result_img

if __name__ == '__main__':
perform_inference(import_onnx(), get_test_image())
MX_SYM, MX_PARAM = import_onnx()
INPUT_IMG, IMG_CB, IMG_CR = get_test_image()
perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)
11 changes: 5 additions & 6 deletions tests/python-pytest/onnx/onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ def test_super_resolution():
assert len(sym.list_outputs()) == 1
assert sym.list_outputs()[0] == 'reshape5_output'

assert len(sym.list_attr()) == 1
assert sym.list_attr()['shape'] == '(1L, 1L, 672L, 672L)'

attrs_keys = sym.attr_dict().keys()
assert len(attrs_keys) == 19
for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7',
Expand All @@ -131,9 +128,11 @@ def test_super_resolution():
LOGGER.info("Asserted the result of the onnx model conversion")

output_img_dim = 672
result_img = super_resolution.perform_inference((sym, params),
super_resolution.get_test_image())
input_image, img_cb, img_cr = super_resolution.get_test_image()
result_img = super_resolution.perform_inference(sym, params, input_image,
img_cb, img_cr)
assert result_img.size == (output_img_dim, output_img_dim)

if __name__ == '__main__':
unittest.main()
test_super_resolution()
#unittest.main()

0 comments on commit c04881d

Please sign in to comment.