Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tests for maxroipool, randomnormal, randomuniform
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 28, 2018
1 parent 1ed28b6 commit d8de0f4
Showing 1 changed file with 76 additions and 21 deletions.
97 changes: 76 additions & 21 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@
import os
import unittest
import logging
import tarfile
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper, helper, load_model
from onnx import TensorProto
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
import backend
Expand All @@ -56,6 +54,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
return np.random.choice(a=[False, True], size=shape).astype(np.float32)


def _fix_attributes(attrs, attribute_mapping):
new_attrs = attrs
attr_modify = attribute_mapping.get('modify', {})
for k, v in attr_modify.items():
new_attrs[v] = new_attrs.pop(k, None)

attr_add = attribute_mapping.get('add', {})
for k, v in attr_add.items():
new_attrs[k] = v

attr_remove = attribute_mapping.get('remove', [])
for k in attr_remove:
if k in new_attrs:
del new_attrs[k]

return new_attrs


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data
:param sym: Symbol
Expand Down Expand Up @@ -118,46 +134,85 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
return model

for test in test_cases:
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
test_op = mxnet_op(*inputsym, **attrs)
mxnet_output = forward_pass(test_op, None, None, names, inputs)
outputshape = np.shape(mxnet_output)
if inputs:
test_op = mxnet_op(*inputsym, **attrs)
mxnet_output = forward_pass(test_op, None, None, names, inputs)
outputshape = np.shape(mxnet_output)
else:
test_op = mxnet_op(**attrs)
shape = attrs.get('shape', (1,))
x = mx.nd.zeros(shape, dtype='float32')
xgrad = mx.nd.zeros(shape, dtype='float32')
exe = test_op.bind(ctx=mx.cpu(), args={'x': x}, args_grad={'x': xgrad})
mxnet_output = exe.forward(is_train=False)[0].asnumpy()
outputshape = np.shape(mxnet_output)

if mxnet_specific:
onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
np.float32,
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs)
onnx_attrs = _fix_attributes(attrs, fix_attrs)
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs)

bkd_rep = backend.prepare(onnxmodel, operation='export')
output = bkd_rep.run(inputs)

npt.assert_almost_equal(output[0], mxnet_output)
if check_value:
npt.assert_almost_equal(output[0], mxnet_output)

if check_shape:
npt.assert_equal(output[0].shape, outputshape)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False)
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
# 'add': {attr_name: value},
# check_value=True/False, check_shape=True/False)
test_cases = [
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_and", mx.sym.broadcast_logical_and, "And",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True,
False),
("test_xor", mx.sym.broadcast_logical_xor, "Xor",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True,
False),
("test_or", mx.sym.broadcast_logical_or, "Or",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True,
False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True,
False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True,
False),
("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))],
{'block_size': 2}, False),
{'block_size': 2}, False, {}, True,
False),
("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
{'ignore_label': 0, 'use_ignore': False}, True),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True)
{'ignore_label': 0, 'use_ignore': False}, True, {}, True,
False),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True, {}, True,
False),
("test_roipool", mx.sym.ROIPooling, "MaxRoiPool",
[[[get_rnd(shape=(8, 6), low=1, high=100, dtype=np.int32)]], [[0, 0, 0, 4, 4]]],
{'pooled_size': (2, 2), 'spatial_scale': 0.7}, False,
{'modify': {'pooled_size': 'pooled_shape'}}, True, False),

# since results would be random, checking for shape alone
("test_random_normal", mx.sym.random_normal, "RandomNormal", [],
{'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True),
("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [],
{'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True)
]

if __name__ == '__main__':
Expand Down

0 comments on commit d8de0f4

Please sign in to comment.