Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: Test for square operator
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 15, 2018
1 parent 33404ab commit e59bda1
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,22 @@
from __future__ import absolute_import
import sys
import os
import unittest
import logging
import tarfile
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper
from onnx import numpy_helper, helper
from onnx import TensorProto
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))
sys.path.insert(0, os.path.join(CURR_PATH, '../../../python/unittest'))
import backend
from common import with_seed

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
URLS = {
Expand Down Expand Up @@ -179,6 +183,29 @@ def test_model_accuracy(model_name, input_shape):
npt.assert_equal(expected.shape, actual.shape)
npt.assert_almost_equal(expected, actual, decimal=3)

@with_seed()
def test_square():
input1 = np.random.randint(1, 10, (2, 3)).astype("float32")

ipsym = mx.sym.Variable("input1")
square = mx.sym.square(data=ipsym)
model = mx.mod.Module(symbol=square, data_names=['input1'], label_names=None)
model.bind(for_training=False, data_shapes=[('input1', np.shape(input1))], label_shapes=None)
model.init_params()

args, auxs = model.get_params()
params = {}
params.update(args)
params.update(auxs)

converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)

numpy_op = np.square(input1)

npt.assert_almost_equal(result, numpy_op)

if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
Expand All @@ -189,3 +216,5 @@ def test_model_accuracy(model_name, input_shape):
# ONNX expected results due to AveragePool issue github issue(#10194)
test_model_accuracy("inception_v1", (1, 3, 224, 224))
test_model_accuracy("inception_v2", (1, 3, 224, 224))

unittest.main()

0 comments on commit e59bda1

Please sign in to comment.