Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: Sample_multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 11, 2018
1 parent 97e0c97 commit a4d88ef
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 1 deletion.
22 changes: 22 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,3 +1655,25 @@ def convert_size(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Size', node, kwargs)

@mx_op.register("_sample_multinomial")
def convert_multinomial(node, **kwargs):
"""Map MXNet's multinomial operator attributes to onnx's
Multinomial operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get("dtype", 'int32'))]
sample_size = convert_string_to_list(attrs.get("shape", '1'))
if len(sample_size) < 2:
sample_size = sample_size[-1]
else:
raise AttributeError("ONNX currently supports integer sample_size only")
node = onnx.helper.make_node(
"Multinomial",
input_nodes,
[name],
dtype=dtype,
sample_size=sample_size,
name=name,
)
return [node]
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# coding: utf-8_
# pylint: disable=invalid-name
"""Operator attributes conversion"""
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import identity, random_uniform, random_normal, sample_multinomial
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
Expand Down Expand Up @@ -48,6 +48,7 @@
'RandomNormal' : random_normal,
'RandomUniformLike' : random_uniform,
'RandomNormalLike' : random_normal,
'Multinomial' : sample_multinomial,
# Arithmetic Operators
'Add' : add,
'Sub' : subtract,
Expand Down
13 changes: 13 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def random_normal(attrs, inputs, proto_obj):
new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'})
return 'random_uniform', new_attr, inputs

def sample_multinomial(attrs, inputs, proto_obj):
"""Draw random samples from a multinomial distribution."""
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
new_attrs = translation_utils._fix_attribute_names(new_attrs, {'sample_size': 'shape'})
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(new_attrs['dtype'])]
return 'sample_multinomial', new_attrs, inputs


# Arithmetic Operations
def add(attrs, inputs, proto_obj):
"""Adding two tensors"""
Expand Down
16 changes: 16 additions & 0 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,22 @@ def test_softmax():
# Comparing result of forward pass before using onnx export, import
npt.assert_almost_equal(result, softmax_out)

@with_seed()
def test_multinomial():
input1 = np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")
shape = (10,)
inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1))]
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=shape)]
nodes = [helper.make_node("Multinomial", ["input1"], ["output"], sample_size=shape, dtype=6)]
graph = helper.make_graph(nodes,
"multinomial_test",
inputs,
outputs)
multinomial_model = helper.make_model(graph)
bkd_rep = backend.prepare(multinomial_model)
output = bkd_rep.run([input1])
assert output[0].shape == shape

@with_seed()
def test_comparison_ops():
"""Test greater, lesser, equal"""
Expand Down

0 comments on commit a4d88ef

Please sign in to comment.