Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-898] ONNX import/export: Sample_multinomial, ONNX export: GlobalLpPool, LpPool #13500

Merged
merged 5 commits into from
Jan 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 66 additions & 17 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def convert_pooling(node, **kwargs):
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
global_pool = get_boolean_attribute_value(attrs, "global_pool")
p_value = attrs.get('p_value', 'None')

pooling_convention = attrs.get('pooling_convention', 'valid')

Expand All @@ -598,26 +599,51 @@ def convert_pooling(node, **kwargs):

pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
pad_dims = pad_dims + pad_dims
pool_types = {"max": "MaxPool", "avg": "AveragePool"}
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
"lp": "GlobalLpPool"}

if pool_type == 'lp' and p_value == 'None':
raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool')

if global_pool:
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
name=name
)
if pool_type == 'lp':
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
p=int(p_value),
name=name
)
else:
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
name=name
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)
if pool_type == 'lp':
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
p=int(p_value),
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)

return [node]

Expand Down Expand Up @@ -1689,3 +1715,26 @@ def convert_logsoftmax(node, **kwargs):
name=name
)
return [node]


@mx_op.register("_sample_multinomial")
def convert_multinomial(node, **kwargs):
"""Map MXNet's multinomial operator attributes to onnx's
Multinomial operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get("dtype", 'int32'))]
sample_size = convert_string_to_list(attrs.get("shape", '1'))
vandanavk marked this conversation as resolved.
Show resolved Hide resolved
if len(sample_size) < 2:
sample_size = sample_size[-1]
else:
raise AttributeError("ONNX currently supports integer sample_size only")
node = onnx.helper.make_node(
"Multinomial",
input_nodes,
[name],
dtype=dtype,
sample_size=sample_size,
name=name,
)
return [node]
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# coding: utf-8_
# pylint: disable=invalid-name
"""Operator attributes conversion"""
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import identity, random_uniform, random_normal, sample_multinomial
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
Expand Down Expand Up @@ -48,6 +48,7 @@
'RandomNormal' : random_normal,
'RandomUniformLike' : random_uniform,
'RandomNormalLike' : random_normal,
'Multinomial' : sample_multinomial,
# Arithmetic Operators
'Add' : add,
'Sub' : subtract,
Expand Down
21 changes: 18 additions & 3 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def random_normal(attrs, inputs, proto_obj):
new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'})
return 'random_uniform', new_attr, inputs

def sample_multinomial(attrs, inputs, proto_obj):
"""Draw random samples from a multinomial distribution."""
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
new_attrs = translation_utils._fix_attribute_names(new_attrs, {'sample_size': 'shape'})
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
return 'sample_multinomial', new_attrs, inputs


# Arithmetic Operations
def add(attrs, inputs, proto_obj):
"""Adding two tensors"""
Expand Down Expand Up @@ -382,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj):
'kernel': (1, 1),
'pool_type': 'lp',
'p_value': p_value})
new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
return 'Pooling', new_attrs, inputs

def linalg_gemm(attrs, inputs, proto_obj):
Expand Down Expand Up @@ -671,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj):
new_attrs = translation_utils._fix_attribute_names(attrs,
{'kernel_shape': 'kernel',
'strides': 'stride',
'pads': 'pad',
'p_value': p_value
'pads': 'pad'
})
new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'pooling_convention': 'valid'
{'pooling_convention': 'valid',
'p_value': p_value
})
new_op = translation_utils._fix_pooling('lp', inputs, new_attrs)
return new_op, new_attrs, inputs
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr):
stride = new_attr.get('stride')
kernel = new_attr.get('kernel')
padding = new_attr.get('pad')
p_value = new_attr.get('p_value')

# Adding default stride.
if stride is None:
Expand Down Expand Up @@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr):
new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width)

# Apply pooling without pads.
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
if pool_type == 'lp':
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel, p_value=p_value)
else:
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
return new_pooling_op

def _fix_bias(op_name, attrs, num_inputs):
Expand Down
2 changes: 0 additions & 2 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
'test_softplus'
],
'import': ['test_gather',
'test_global_lppooling',
'test_softsign',
'test_reduce_',
'test_mean',
Expand All @@ -89,7 +88,6 @@
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
'test_LpPool_',
'test_split_equal',
'test_hardmax'
],
Expand Down
83 changes: 67 additions & 16 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
return np.random.choice(a=[False, True], size=shape).astype(np.float32)


def _fix_attributes(attrs, attribute_mapping):
new_attrs = attrs
attr_modify = attribute_mapping.get('modify', {})
for k, v in attr_modify.items():
new_attrs[v] = new_attrs.pop(k, None)

attr_add = attribute_mapping.get('add', {})
for k, v in attr_add.items():
new_attrs[k] = v

attr_remove = attribute_mapping.get('remove', [])
for k in attr_remove:
if k in new_attrs:
del new_attrs[k]

return new_attrs


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data
:param sym: Symbol
Expand Down Expand Up @@ -118,7 +136,7 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
return model

for test in test_cases:
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
test_op = mxnet_op(*inputsym, **attrs)
Expand All @@ -131,33 +149,66 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs)
onnx_attrs = _fix_attributes(attrs, fix_attrs)
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs)

bkd_rep = backend.prepare(onnxmodel, operation='export')
output = bkd_rep.run(inputs)

npt.assert_almost_equal(output[0], mxnet_output)
if check_value:
npt.assert_almost_equal(output[0], mxnet_output)

if check_shape:
npt.assert_equal(output[0].shape, outputshape)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False)
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
# 'add': {attr_name: value},
# check_value=True/False, check_shape=True/False)
test_cases = [
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_and", mx.sym.broadcast_logical_and, "And",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_xor", mx.sym.broadcast_logical_xor, "Xor",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_or", mx.sym.broadcast_logical_or, "Or",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, False),
("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))],
{'block_size': 2}, False),
{'block_size': 2}, False, {}, True, False),
("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
{'ignore_label': 0, 'use_ignore': False}, True),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True)
{'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),

# since results would be random, checking for shape alone
("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
[np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
]

if __name__ == '__main__':
Expand Down